mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix sampling in training with mutiple gpus ref #989
This commit is contained in:
@@ -4603,7 +4603,7 @@ def line_to_prompt_dict(line: str) -> dict:
|
|||||||
|
|
||||||
def sample_images_common(
|
def sample_images_common(
|
||||||
pipe_class,
|
pipe_class,
|
||||||
accelerator,
|
accelerator: Accelerator,
|
||||||
args: argparse.Namespace,
|
args: argparse.Namespace,
|
||||||
epoch,
|
epoch,
|
||||||
steps,
|
steps,
|
||||||
@@ -4640,6 +4640,13 @@ def sample_images_common(
|
|||||||
org_vae_device = vae.device # CPUにいるはず
|
org_vae_device = vae.device # CPUにいるはず
|
||||||
vae.to(device)
|
vae.to(device)
|
||||||
|
|
||||||
|
# unwrap unet and text_encoder(s)
|
||||||
|
unet = accelerator.unwrap_model(unet)
|
||||||
|
if isinstance(text_encoder, (list, tuple)):
|
||||||
|
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
|
||||||
|
else:
|
||||||
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
# read prompts
|
# read prompts
|
||||||
|
|
||||||
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
||||||
|
|||||||
Reference in New Issue
Block a user