Fix sampling in multi-GPU training

This commit is contained in:
Kohya S
2023-07-15 11:21:14 +09:00
parent 9de357e373
commit 81fa54837f
3 changed files with 33 additions and 44 deletions

View File

@@ -446,9 +446,7 @@ def prepare_controlnet_image(
for image_ in image: for image_ in image:
image_ = image_.convert("RGB") image_ = image_.convert("RGB")
image_ = image_.resize( image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
(width, height), resample=PIL_INTERPOLATION["lanczos"]
)
image_ = np.array(image_) image_ = np.array(image_)
image_ = image_[None, :] image_ = image_[None, :]
images.append(image_) images.append(image_)
@@ -479,6 +477,7 @@ def prepare_controlnet_image(
return image return image
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
r""" r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
@@ -889,8 +888,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
mask = None mask = None
if controlnet_image is not None: if controlnet_image is not None:
controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False) controlnet_image = prepare_controlnet_image(
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
)
# 5. set timesteps # 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device) self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -930,8 +930,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
guess_mode=False, guess_mode=False,
return_dict=False, return_dict=False,
) )
unet_additional_args['down_block_additional_residuals'] = down_block_res_samples unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
# predict the noise residual # predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
@@ -956,20 +956,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
if is_cancelled_callback is not None and is_cancelled_callback(): if is_cancelled_callback is not None and is_cancelled_callback():
return None return None
return latents
def latents_to_image(self, latents):
# 9. Post-processing # 9. Post-processing
image = self.decode_latents(latents) image = self.decode_latents(latents.to(self.vae.dtype))
# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
return image
if not return_dict:
return image, has_nsfw_concept
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def text2img( def text2img(
self, self,

View File

@@ -1027,20 +1027,13 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if is_cancelled_callback is not None and is_cancelled_callback(): if is_cancelled_callback is not None and is_cancelled_callback():
return None return None
return latents
def latents_to_image(self, latents):
# 9. Post-processing # 9. Post-processing
image = self.decode_latents(latents.to(torch.float32)) image = self.decode_latents(latents.to(self.vae.dtype))
# 10. Run safety checker
image, has_nsfw_concept = image, None # self.run_safety_checker(image, device, text_embeddings.dtype)
# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)
return image
if not return_dict:
return image, has_nsfw_concept
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
# copy from pil_utils.py # copy from pil_utils.py
def numpy_to_pil(self, images: np.ndarray) -> Image.Image: def numpy_to_pil(self, images: np.ndarray) -> Image.Image:

View File

@@ -3964,7 +3964,8 @@ def sample_images_common(
print(f"width: {width}") print(f"width: {width}")
print(f"sample_steps: {sample_steps}") print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}") print(f"scale: {scale}")
image = pipeline( with accelerator.autocast():
latents = pipeline(
prompt=prompt, prompt=prompt,
height=height, height=height,
width=width, width=width,
@@ -3973,7 +3974,9 @@ def sample_images_common(
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
controlnet=controlnet, controlnet=controlnet,
controlnet_image=controlnet_image, controlnet_image=controlnet_image,
).images[0] )
image = pipeline.latents_to_image(latents)[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"