mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix sampling in multi GPU training
This commit is contained in:
@@ -446,9 +446,7 @@ def prepare_controlnet_image(
|
||||
|
||||
for image_ in image:
|
||||
image_ = image_.convert("RGB")
|
||||
image_ = image_.resize(
|
||||
(width, height), resample=PIL_INTERPOLATION["lanczos"]
|
||||
)
|
||||
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image_ = np.array(image_)
|
||||
image_ = image_[None, :]
|
||||
images.append(image_)
|
||||
@@ -479,6 +477,7 @@ def prepare_controlnet_image(
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
||||
@@ -889,8 +888,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
mask = None
|
||||
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False)
|
||||
|
||||
controlnet_image = prepare_controlnet_image(
|
||||
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
||||
)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -930,8 +930,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
guess_mode=False,
|
||||
return_dict=False,
|
||||
)
|
||||
unet_additional_args['down_block_additional_residuals'] = down_block_res_samples
|
||||
unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample
|
||||
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
||||
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
|
||||
@@ -956,20 +956,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
return latents
|
||||
|
||||
def latents_to_image(self, latents):
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 10. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return image, has_nsfw_concept
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
image = self.decode_latents(latents.to(self.vae.dtype))
|
||||
image = self.numpy_to_pil(image)
|
||||
return image
|
||||
|
||||
def text2img(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user