Enable `#` comment lines in the prompt file; record the raw prompt in metadata

This commit is contained in:
Kohya S
2023-12-12 08:20:36 +09:00
parent 07ef03d340
commit d61ecb26fd
2 changed files with 32 additions and 12 deletions

View File

@@ -2184,6 +2184,7 @@ class BatchDataBase(NamedTuple):
mask_image: Any mask_image: Any
clip_prompt: str clip_prompt: str
guide_image: Any guide_image: Any
raw_prompt: str
class BatchDataExt(NamedTuple): class BatchDataExt(NamedTuple):
@@ -2710,7 +2711,7 @@ def main(args):
print(f"reading prompts from {args.from_file}") print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f: with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines() prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0] prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
elif args.prompt is not None: elif args.prompt is not None:
prompt_list = [args.prompt] prompt_list = [args.prompt]
else: else:
@@ -2954,13 +2955,14 @@ def main(args):
# このバッチの情報を取り出す # このバッチの情報を取り出す
( (
return_latents, return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image), (step_first, _, _, _, init_image, mask_image, _, guide_image, _),
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts), (width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
) = batch[0] ) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
prompts = [] prompts = []
negative_prompts = [] negative_prompts = []
raw_prompts = []
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
noises = [ noises = [
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2991,11 +2993,16 @@ def main(args):
all_images_are_same = True all_images_are_same = True
all_masks_are_same = True all_masks_are_same = True
all_guide_images_are_same = True all_guide_images_are_same = True
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): for i, (
_,
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
_,
) in enumerate(batch):
prompts.append(prompt) prompts.append(prompt)
negative_prompts.append(negative_prompt) negative_prompts.append(negative_prompt)
seeds.append(seed) seeds.append(seed)
clip_prompts.append(clip_prompt) clip_prompts.append(clip_prompt)
raw_prompts.append(raw_prompt)
if init_image is not None: if init_image is not None:
init_images.append(init_image) init_images.append(init_image)
@@ -3087,8 +3094,8 @@ def main(args):
# save image # save image
highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts) zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
): ):
if highres_fix: if highres_fix:
seed -= 1 # record original seed seed -= 1 # record original seed
@@ -3104,6 +3111,8 @@ def main(args):
metadata.add_text("negative-scale", str(negative_scale)) metadata.add_text("negative-scale", str(negative_scale))
if clip_prompt is not None: if clip_prompt is not None:
metadata.add_text("clip-prompt", clip_prompt) metadata.add_text("clip-prompt", clip_prompt)
if raw_prompt is not None:
metadata.add_text("raw-prompt", raw_prompt)
if args.use_original_file_name and init_images is not None: if args.use_original_file_name and init_images is not None:
if type(init_images) is list: if type(init_images) is list:
@@ -3438,7 +3447,9 @@ def main(args):
b1 = BatchData( b1 = BatchData(
False, False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
),
BatchDataExt( BatchDataExt(
width, width,
height, height,

View File

@@ -1449,6 +1449,7 @@ class BatchDataBase(NamedTuple):
mask_image: Any mask_image: Any
clip_prompt: str clip_prompt: str
guide_image: Any guide_image: Any
raw_prompt: str
class BatchDataExt(NamedTuple): class BatchDataExt(NamedTuple):
@@ -1918,7 +1919,7 @@ def main(args):
print(f"reading prompts from {args.from_file}") print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f: with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines() prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0] prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
elif args.prompt is not None: elif args.prompt is not None:
prompt_list = [args.prompt] prompt_list = [args.prompt]
else: else:
@@ -2190,7 +2191,7 @@ def main(args):
# このバッチの情報を取り出す # このバッチの情報を取り出す
( (
return_latents, return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image), (step_first, _, _, _, init_image, mask_image, _, guide_image, _),
( (
width, width,
height, height,
@@ -2212,6 +2213,7 @@ def main(args):
prompts = [] prompts = []
negative_prompts = [] negative_prompts = []
raw_prompts = []
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
noises = [ noises = [
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2242,11 +2244,16 @@ def main(args):
all_images_are_same = True all_images_are_same = True
all_masks_are_same = True all_masks_are_same = True
all_guide_images_are_same = True all_guide_images_are_same = True
for i, (_, (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), _) in enumerate(batch): for i, (
_,
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
_,
) in enumerate(batch):
prompts.append(prompt) prompts.append(prompt)
negative_prompts.append(negative_prompt) negative_prompts.append(negative_prompt)
seeds.append(seed) seeds.append(seed)
clip_prompts.append(clip_prompt) clip_prompts.append(clip_prompt)
raw_prompts.append(raw_prompt)
if init_image is not None: if init_image is not None:
init_images.append(init_image) init_images.append(init_image)
@@ -2344,8 +2351,8 @@ def main(args):
# save image # save image
highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
for i, (image, prompt, negative_prompts, seed, clip_prompt) in enumerate( for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts) zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts )
): ):
if highres_fix: if highres_fix:
seed -= 1 # record original seed seed -= 1 # record original seed
@@ -2361,6 +2368,8 @@ def main(args):
metadata.add_text("negative-scale", str(negative_scale)) metadata.add_text("negative-scale", str(negative_scale))
if clip_prompt is not None: if clip_prompt is not None:
metadata.add_text("clip-prompt", clip_prompt) metadata.add_text("clip-prompt", clip_prompt)
if raw_prompt is not None:
metadata.add_text("raw-prompt", raw_prompt)
metadata.add_text("original-height", str(original_height)) metadata.add_text("original-height", str(original_height))
metadata.add_text("original-width", str(original_width)) metadata.add_text("original-width", str(original_width))
metadata.add_text("original-height-negative", str(original_height_negative)) metadata.add_text("original-height-negative", str(original_height_negative))
@@ -2736,7 +2745,7 @@ def main(args):
b1 = BatchData( b1 = BatchData(
False, False,
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
BatchDataExt( BatchDataExt(
width, width,
height, height,