fix sampling in multi GPU training

This commit is contained in:
Kohya S
2023-07-15 11:21:14 +09:00
parent 9de357e373
commit 81fa54837f
3 changed files with 33 additions and 44 deletions

View File

@@ -3964,16 +3964,19 @@ def sample_images_common(
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
image = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
).images[0]
with accelerator.autocast():
latents = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
)
image = pipeline.latents_to_image(latents)[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"