From 81fa54837f538a2c9d294d1aa66a45421edeeaee Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 15 Jul 2023 11:21:14 +0900 Subject: [PATCH] fix sampling in multi GPU training --- library/lpw_stable_diffusion.py | 33 +++++++++++----------------- library/sdxl_lpw_stable_diffusion.py | 21 ++++++------------ library/train_util.py | 23 ++++++++++--------- 3 files changed, 33 insertions(+), 44 deletions(-) diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 2605c864..9dce91a7 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -446,9 +446,7 @@ def prepare_controlnet_image( for image_ in image: image_ = image_.convert("RGB") - image_ = image_.resize( - (width, height), resample=PIL_INTERPOLATION["lanczos"] - ) + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) image_ = np.array(image_) image_ = image_[None, :] images.append(image_) @@ -479,6 +477,7 @@ def prepare_controlnet_image( return image + class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing @@ -889,8 +888,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): mask = None if controlnet_image is not None: - controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False) - + controlnet_image = prepare_controlnet_image( + controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False + ) # 5. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -930,8 +930,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): guess_mode=False, return_dict=False, ) - unet_additional_args['down_block_additional_residuals'] = down_block_res_samples - unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample + unet_additional_args["down_block_additional_residuals"] = down_block_res_samples + unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample @@ -956,20 +956,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): if is_cancelled_callback is not None and is_cancelled_callback(): return None + return latents + + def latents_to_image(self, latents): # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) - - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return image, has_nsfw_concept - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + image = self.decode_latents(latents.to(self.vae.dtype)) + image = self.numpy_to_pil(image) + return image def text2img( self, diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index a65a1d96..7f88469f 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -1005,7 +1005,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: # predict the noise residual noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding) - noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training + noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training # perform guidance if do_classifier_free_guidance: @@ -1027,20 +1027,13 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if is_cancelled_callback is not None and is_cancelled_callback(): return None + return latents + + def latents_to_image(self, latents): # 9. Post-processing - image = self.decode_latents(latents.to(torch.float32)) - - # 10. Run safety checker - image, has_nsfw_concept = image, None # self.run_safety_checker(image, device, text_embeddings.dtype) - - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - - if not return_dict: - return image, has_nsfw_concept - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + image = self.decode_latents(latents.to(self.vae.dtype)) + image = self.numpy_to_pil(image) + return image # copy from pil_utils.py def numpy_to_pil(self, images: np.ndarray) -> Image.Image: diff --git a/library/train_util.py b/library/train_util.py index 9438a189..15c58cc8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3964,16 +3964,19 @@ def sample_images_common( print(f"width: {width}") print(f"sample_steps: {sample_steps}") print(f"scale: {scale}") - image = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ).images[0] + with accelerator.autocast(): + latents = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ) + + image = pipeline.latents_to_image(latents)[0] ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"