fix for multi gpu training

This commit is contained in:
ddPn08
2023-03-03 00:21:18 +09:00
parent 8d5ba29363
commit 87846c043f
2 changed files with 30 additions and 17 deletions

View File

@@ -2294,6 +2294,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
with torch.no_grad():
with accelerator.autocast():
for i, prompt in enumerate(prompts):
if not accelerator.is_main_process:
continue
prompt = prompt.strip()
if len(prompt) == 0 or prompt[0] == '#':
continue
@@ -2351,6 +2353,12 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())