fix NaN in sampling image

This commit is contained in:
Kohya S
2023-07-11 23:18:35 +09:00
parent 2e67d74df4
commit 814996b14f
2 changed files with 118 additions and 118 deletions

View File

@@ -922,7 +922,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
if up1 is not None:
uncond_pool = up1
dtype = text_embeddings_list[0].dtype
dtype = self.unet.dtype
# 4. Preprocess image and mask
if isinstance(image, PIL.Image.Image):

View File

@@ -3874,127 +3874,127 @@ def sample_images_common(
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
with torch.no_grad():
with accelerator.autocast():
for i, prompt in enumerate(prompts):
if not accelerator.is_main_process:
continue
# with accelerator.autocast():
for i, prompt in enumerate(prompts):
if not accelerator.is_main_process:
continue
if isinstance(prompt, dict):
negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30)
width = prompt.get("width", 512)
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
controlnet_image = prompt.get("controlnet_image")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#":
# continue
if isinstance(prompt, dict):
negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30)
width = prompt.get("width", 512)
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
controlnet_image = prompt.get("controlnet_image")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#":
# continue
# subset of gen_img_diffusers
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
controlnet_image = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m: # negative prompt
controlnet_image = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
image = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
).images[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
)
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
# subset of gen_img_diffusers
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
controlnet_image = None
for parg in prompt_args:
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m: # negative prompt
controlnet_image = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
image = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
).images[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
)
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# clear pipeline and cache to reduce vram usage
del pipeline