mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix sampling in multi GPU training
This commit is contained in:
@@ -3964,16 +3964,19 @@ def sample_images_common(
|
||||
print(f"width: {width}")
|
||||
print(f"sample_steps: {sample_steps}")
|
||||
print(f"scale: {scale}")
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=sample_steps,
|
||||
guidance_scale=scale,
|
||||
negative_prompt=negative_prompt,
|
||||
controlnet=controlnet,
|
||||
controlnet_image=controlnet_image,
|
||||
).images[0]
|
||||
with accelerator.autocast():
|
||||
latents = pipeline(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=sample_steps,
|
||||
guidance_scale=scale,
|
||||
negative_prompt=negative_prompt,
|
||||
controlnet=controlnet,
|
||||
controlnet_image=controlnet_image,
|
||||
)
|
||||
|
||||
image = pipeline.latents_to_image(latents)[0]
|
||||
|
||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||
|
||||
Reference in New Issue
Block a user