mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix sampling in multi GPU training
This commit is contained in:
@@ -446,9 +446,7 @@ def prepare_controlnet_image(
|
||||
|
||||
for image_ in image:
|
||||
image_ = image_.convert("RGB")
|
||||
image_ = image_.resize(
|
||||
(width, height), resample=PIL_INTERPOLATION["lanczos"]
|
||||
)
|
||||
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
|
||||
image_ = np.array(image_)
|
||||
image_ = image_[None, :]
|
||||
images.append(image_)
|
||||
@@ -479,6 +477,7 @@ def prepare_controlnet_image(
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
|
||||
@@ -889,8 +888,9 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
mask = None
|
||||
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False)
|
||||
|
||||
controlnet_image = prepare_controlnet_image(
|
||||
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
|
||||
)
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -930,8 +930,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
guess_mode=False,
|
||||
return_dict=False,
|
||||
)
|
||||
unet_additional_args['down_block_additional_residuals'] = down_block_res_samples
|
||||
unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample
|
||||
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
|
||||
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
|
||||
@@ -956,20 +956,13 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
return latents
|
||||
|
||||
def latents_to_image(self, latents):
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
# 10. Run safety checker
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return image, has_nsfw_concept
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
image = self.decode_latents(latents.to(self.vae.dtype))
|
||||
image = self.numpy_to_pil(image)
|
||||
return image
|
||||
|
||||
def text2img(
|
||||
self,
|
||||
|
||||
@@ -1005,7 +1005,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
|
||||
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
|
||||
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
@@ -1027,20 +1027,13 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
return latents
|
||||
|
||||
def latents_to_image(self, latents):
|
||||
# 9. Post-processing
|
||||
image = self.decode_latents(latents.to(torch.float32))
|
||||
|
||||
# 10. Run safety checker
|
||||
image, has_nsfw_concept = image, None # self.run_safety_checker(image, device, text_embeddings.dtype)
|
||||
|
||||
# 11. Convert to PIL
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return image, has_nsfw_concept
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
image = self.decode_latents(latents.to(self.vae.dtype))
|
||||
image = self.numpy_to_pil(image)
|
||||
return image
|
||||
|
||||
# copy from pil_utils.py
|
||||
def numpy_to_pil(self, images: np.ndarray) -> Image.Image:
|
||||
|
||||
@@ -3964,16 +3964,19 @@ def sample_images_common(
|
||||
print(f"width: {width}")
|
||||
print(f"sample_steps: {sample_steps}")
|
||||
print(f"scale: {scale}")
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=sample_steps,
|
||||
guidance_scale=scale,
|
||||
negative_prompt=negative_prompt,
|
||||
controlnet=controlnet,
|
||||
controlnet_image=controlnet_image,
|
||||
).images[0]
|
||||
with accelerator.autocast():
|
||||
latents = pipeline(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=sample_steps,
|
||||
guidance_scale=scale,
|
||||
negative_prompt=negative_prompt,
|
||||
controlnet=controlnet,
|
||||
controlnet_image=controlnet_image,
|
||||
)
|
||||
|
||||
image = pipeline.latents_to_image(latents)[0]
|
||||
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
|
||||
Reference in New Issue
Block a user