mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix NaN in sampling image
This commit is contained in:
@@ -922,7 +922,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline:
|
|||||||
if up1 is not None:
|
if up1 is not None:
|
||||||
uncond_pool = up1
|
uncond_pool = up1
|
||||||
|
|
||||||
dtype = text_embeddings_list[0].dtype
|
dtype = self.unet.dtype
|
||||||
|
|
||||||
# 4. Preprocess image and mask
|
# 4. Preprocess image and mask
|
||||||
if isinstance(image, PIL.Image.Image):
|
if isinstance(image, PIL.Image.Image):
|
||||||
|
|||||||
@@ -3874,127 +3874,127 @@ def sample_images_common(
|
|||||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with accelerator.autocast():
|
# with accelerator.autocast():
|
||||||
for i, prompt in enumerate(prompts):
|
for i, prompt in enumerate(prompts):
|
||||||
if not accelerator.is_main_process:
|
if not accelerator.is_main_process:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(prompt, dict):
|
if isinstance(prompt, dict):
|
||||||
negative_prompt = prompt.get("negative_prompt")
|
negative_prompt = prompt.get("negative_prompt")
|
||||||
sample_steps = prompt.get("sample_steps", 30)
|
sample_steps = prompt.get("sample_steps", 30)
|
||||||
width = prompt.get("width", 512)
|
width = prompt.get("width", 512)
|
||||||
height = prompt.get("height", 512)
|
height = prompt.get("height", 512)
|
||||||
scale = prompt.get("scale", 7.5)
|
scale = prompt.get("scale", 7.5)
|
||||||
seed = prompt.get("seed")
|
seed = prompt.get("seed")
|
||||||
controlnet_image = prompt.get("controlnet_image")
|
controlnet_image = prompt.get("controlnet_image")
|
||||||
prompt = prompt.get("prompt")
|
prompt = prompt.get("prompt")
|
||||||
else:
|
else:
|
||||||
# prompt = prompt.strip()
|
# prompt = prompt.strip()
|
||||||
# if len(prompt) == 0 or prompt[0] == "#":
|
# if len(prompt) == 0 or prompt[0] == "#":
|
||||||
# continue
|
# continue
|
||||||
|
|
||||||
# subset of gen_img_diffusers
|
# subset of gen_img_diffusers
|
||||||
prompt_args = prompt.split(" --")
|
prompt_args = prompt.split(" --")
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
negative_prompt = None
|
negative_prompt = None
|
||||||
sample_steps = 30
|
sample_steps = 30
|
||||||
width = height = 512
|
width = height = 512
|
||||||
scale = 7.5
|
scale = 7.5
|
||||||
seed = None
|
seed = None
|
||||||
controlnet_image = None
|
controlnet_image = None
|
||||||
for parg in prompt_args:
|
for parg in prompt_args:
|
||||||
try:
|
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
width = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
height = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
seed = int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m: # steps
|
|
||||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
|
||||||
if m: # scale
|
|
||||||
scale = float(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
|
||||||
if m: # negative prompt
|
|
||||||
negative_prompt = m.group(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
|
||||||
if m: # negative prompt
|
|
||||||
controlnet_image = m.group(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
except ValueError as ex:
|
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
|
||||||
print(ex)
|
|
||||||
|
|
||||||
if seed is not None:
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed(seed)
|
|
||||||
|
|
||||||
if prompt_replacement is not None:
|
|
||||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
|
||||||
if negative_prompt is not None:
|
|
||||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
|
||||||
|
|
||||||
if controlnet_image is not None:
|
|
||||||
controlnet_image = Image.open(controlnet_image).convert("RGB")
|
|
||||||
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
|
|
||||||
|
|
||||||
height = max(64, height - height % 8) # round to divisible by 8
|
|
||||||
width = max(64, width - width % 8) # round to divisible by 8
|
|
||||||
print(f"prompt: {prompt}")
|
|
||||||
print(f"negative_prompt: {negative_prompt}")
|
|
||||||
print(f"height: {height}")
|
|
||||||
print(f"width: {width}")
|
|
||||||
print(f"sample_steps: {sample_steps}")
|
|
||||||
print(f"scale: {scale}")
|
|
||||||
image = pipeline(
|
|
||||||
prompt=prompt,
|
|
||||||
height=height,
|
|
||||||
width=width,
|
|
||||||
num_inference_steps=sample_steps,
|
|
||||||
guidance_scale=scale,
|
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
controlnet=controlnet,
|
|
||||||
controlnet_image=controlnet_image,
|
|
||||||
).images[0]
|
|
||||||
|
|
||||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
|
||||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
|
||||||
seed_suffix = "" if seed is None else f"_{seed}"
|
|
||||||
img_filename = (
|
|
||||||
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
|
|
||||||
)
|
|
||||||
|
|
||||||
image.save(os.path.join(save_dir, img_filename))
|
|
||||||
|
|
||||||
# wandb有効時のみログを送信
|
|
||||||
try:
|
|
||||||
wandb_tracker = accelerator.get_tracker("wandb")
|
|
||||||
try:
|
try:
|
||||||
import wandb
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
if m:
|
||||||
raise ImportError("No wandb / wandb がインストールされていないようです")
|
width = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
except: # wandb 無効時
|
if m:
|
||||||
pass
|
height = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m:
|
||||||
|
seed = int(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
|
if m: # steps
|
||||||
|
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
|
if m: # scale
|
||||||
|
scale = float(m.group(1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
|
if m: # negative prompt
|
||||||
|
negative_prompt = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||||
|
if m: # negative prompt
|
||||||
|
controlnet_image = m.group(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
except ValueError as ex:
|
||||||
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
|
print(ex)
|
||||||
|
|
||||||
|
if seed is not None:
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
|
||||||
|
if prompt_replacement is not None:
|
||||||
|
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
|
if negative_prompt is not None:
|
||||||
|
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
|
|
||||||
|
if controlnet_image is not None:
|
||||||
|
controlnet_image = Image.open(controlnet_image).convert("RGB")
|
||||||
|
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
|
||||||
|
|
||||||
|
height = max(64, height - height % 8) # round to divisible by 8
|
||||||
|
width = max(64, width - width % 8) # round to divisible by 8
|
||||||
|
print(f"prompt: {prompt}")
|
||||||
|
print(f"negative_prompt: {negative_prompt}")
|
||||||
|
print(f"height: {height}")
|
||||||
|
print(f"width: {width}")
|
||||||
|
print(f"sample_steps: {sample_steps}")
|
||||||
|
print(f"scale: {scale}")
|
||||||
|
image = pipeline(
|
||||||
|
prompt=prompt,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_inference_steps=sample_steps,
|
||||||
|
guidance_scale=scale,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
controlnet=controlnet,
|
||||||
|
controlnet_image=controlnet_image,
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||||
|
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||||
|
seed_suffix = "" if seed is None else f"_{seed}"
|
||||||
|
img_filename = (
|
||||||
|
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
|
||||||
|
)
|
||||||
|
|
||||||
|
image.save(os.path.join(save_dir, img_filename))
|
||||||
|
|
||||||
|
# wandb有効時のみログを送信
|
||||||
|
try:
|
||||||
|
wandb_tracker = accelerator.get_tracker("wandb")
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
|
||||||
|
raise ImportError("No wandb / wandb がインストールされていないようです")
|
||||||
|
|
||||||
|
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
|
||||||
|
except: # wandb 無効時
|
||||||
|
pass
|
||||||
|
|
||||||
# clear pipeline and cache to reduce vram usage
|
# clear pipeline and cache to reduce vram usage
|
||||||
del pipeline
|
del pipeline
|
||||||
|
|||||||
Reference in New Issue
Block a user