mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix sampling in multi GPU training
This commit is contained in:
@@ -446,9 +446,7 @@ def prepare_controlnet_image(
|
|||||||
|
|
||||||
for image_ in image:
|
for image_ in image:
|
||||||
image_ = image_.convert("RGB")
|
image_ = image_.convert("RGB")
|
||||||
image_ = image_.resize(
|
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||||
(width, height), resample=PIL_INTERPOLATION["lanczos"]
|
|
||||||
)
|
|
||||||
image_ = np.array(image_)
|
image_ = np.array(image_)
|
||||||
image_ = image_[None, :]
|
image_ = image_[None, :]
|
||||||
images.append(image_)
|
images.append(image_)
|
||||||
@@ -479,6 +477,7 @@ def prepare_controlnet_image(
|
|||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
||||||
@@ -889,8 +888,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
mask = None
|
mask = None
|
||||||
|
|
||||||
if controlnet_image is not None:
|
if controlnet_image is not None:
|
||||||
controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False)
|
controlnet_image = prepare_controlnet_image(
|
||||||
|
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
||||||
|
)
|
||||||
|
|
||||||
# 5. set timesteps
|
# 5. set timesteps
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
@@ -930,8 +930,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
guess_mode=False,
|
guess_mode=False,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)
|
)
|
||||||
unet_additional_args['down_block_additional_residuals'] = down_block_res_samples
|
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
||||||
unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample
|
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
|
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
|
||||||
@@ -956,20 +956,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|||||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def latents_to_image(self, latents):
|
||||||
# 9. Post-processing
|
# 9. Post-processing
|
||||||
image = self.decode_latents(latents)
|
image = self.decode_latents(latents.to(self.vae.dtype))
|
||||||
|
image = self.numpy_to_pil(image)
|
||||||
# 10. Run safety checker
|
return image
|
||||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
|
||||||
|
|
||||||
# 11. Convert to PIL
|
|
||||||
if output_type == "pil":
|
|
||||||
image = self.numpy_to_pil(image)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return image, has_nsfw_concept
|
|
||||||
|
|
||||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
|
||||||
|
|
||||||
def text2img(
|
def text2img(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1005,7 +1005,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
|||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
|
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
|
||||||
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
|
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
|
||||||
|
|
||||||
# perform guidance
|
# perform guidance
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
@@ -1027,20 +1027,13 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
|||||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def latents_to_image(self, latents):
|
||||||
# 9. Post-processing
|
# 9. Post-processing
|
||||||
image = self.decode_latents(latents.to(torch.float32))
|
image = self.decode_latents(latents.to(self.vae.dtype))
|
||||||
|
image = self.numpy_to_pil(image)
|
||||||
# 10. Run safety checker
|
return image
|
||||||
image, has_nsfw_concept = image, None # self.run_safety_checker(image, device, text_embeddings.dtype)
|
|
||||||
|
|
||||||
# 11. Convert to PIL
|
|
||||||
if output_type == "pil":
|
|
||||||
image = self.numpy_to_pil(image)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
return image, has_nsfw_concept
|
|
||||||
|
|
||||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
|
||||||
|
|
||||||
# copy from pil_utils.py
|
# copy from pil_utils.py
|
||||||
def numpy_to_pil(self, images: np.ndarray) -> Image.Image:
|
def numpy_to_pil(self, images: np.ndarray) -> Image.Image:
|
||||||
|
|||||||
@@ -3964,16 +3964,19 @@ def sample_images_common(
|
|||||||
print(f"width: {width}")
|
print(f"width: {width}")
|
||||||
print(f"sample_steps: {sample_steps}")
|
print(f"sample_steps: {sample_steps}")
|
||||||
print(f"scale: {scale}")
|
print(f"scale: {scale}")
|
||||||
image = pipeline(
|
with accelerator.autocast():
|
||||||
prompt=prompt,
|
latents = pipeline(
|
||||||
height=height,
|
prompt=prompt,
|
||||||
width=width,
|
height=height,
|
||||||
num_inference_steps=sample_steps,
|
width=width,
|
||||||
guidance_scale=scale,
|
num_inference_steps=sample_steps,
|
||||||
negative_prompt=negative_prompt,
|
guidance_scale=scale,
|
||||||
controlnet=controlnet,
|
negative_prompt=negative_prompt,
|
||||||
controlnet_image=controlnet_image,
|
controlnet=controlnet,
|
||||||
).images[0]
|
controlnet_image=controlnet_image,
|
||||||
|
)
|
||||||
|
|
||||||
|
image = pipeline.latents_to_image(latents)[0]
|
||||||
|
|
||||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||||
|
|||||||
Reference in New Issue
Block a user