fix sampling in multi GPU training

This commit is contained in:
Kohya S
2023-07-15 11:21:14 +09:00
parent 9de357e373
commit 81fa54837f
3 changed files with 33 additions and 44 deletions

View File

@@ -446,9 +446,7 @@ def prepare_controlnet_image(
for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize(
(width, height), resample=PIL_INTERPOLATION["lanczos"]
)
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
@@ -479,6 +477,7 @@ def prepare_controlnet_image(
return image
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
@@ -889,8 +888,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
mask = None
if controlnet_image is not None:
controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False)
controlnet_image = prepare_controlnet_image(
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -930,8 +930,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
guess_mode=False,
return_dict=False,
)
unet_additional_args['down_block_additional_residuals'] = down_block_res_samples
unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
@@ -956,20 +956,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
if is_cancelled_callback is not None and is_cancelled_callback():
return None
return latents
def latents_to_image(self, latents):
# 9. Post-processing
image = self.decode_latents(latents)
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return image, has_nsfw_concept
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
image = self.decode_latents(latents.to(self.vae.dtype))
image = self.numpy_to_pil(image)
return image
def text2img(
self,