add prompt option '--f' for filename

This commit is contained in:
Kohya S
2024-05-15 20:21:49 +09:00
parent 589c2aa025
commit 153764a687
2 changed files with 43 additions and 15 deletions

View File

@@ -179,6 +179,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving.
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。
@@ -219,6 +221,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。
### Apr 7, 2024 / 2024-04-07: v0.8.7 ### Apr 7, 2024 / 2024-04-07: v0.8.7

View File

@@ -1435,6 +1435,7 @@ class BatchDataBase(NamedTuple):
clip_prompt: str clip_prompt: str
guide_image: Any guide_image: Any
raw_prompt: str raw_prompt: str
file_name: Optional[str]
class BatchDataExt(NamedTuple): class BatchDataExt(NamedTuple):
@@ -2316,7 +2317,7 @@ def main(args):
# このバッチの情報を取り出す # このバッチの情報を取り出す
( (
return_latents, return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image, _), (step_first, _, _, _, init_image, mask_image, _, guide_image, _, _),
( (
width, width,
height, height,
@@ -2339,6 +2340,7 @@ def main(args):
prompts = [] prompts = []
negative_prompts = [] negative_prompts = []
raw_prompts = [] raw_prompts = []
filenames = []
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
noises = [ noises = [
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2371,7 +2373,7 @@ def main(args):
all_guide_images_are_same = True all_guide_images_are_same = True
for i, ( for i, (
_, _,
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename),
_, _,
) in enumerate(batch): ) in enumerate(batch):
prompts.append(prompt) prompts.append(prompt)
@@ -2379,6 +2381,7 @@ def main(args):
seeds.append(seed) seeds.append(seed)
clip_prompts.append(clip_prompt) clip_prompts.append(clip_prompt)
raw_prompts.append(raw_prompt) raw_prompts.append(raw_prompt)
filenames.append(filename)
if init_image is not None: if init_image is not None:
init_images.append(init_image) init_images.append(init_image)
@@ -2478,8 +2481,8 @@ def main(args):
# save image # save image
highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames)
): ):
if highres_fix: if highres_fix:
seed -= 1 # record original seed seed -= 1 # record original seed
@@ -2505,17 +2508,23 @@ def main(args):
metadata.add_text("crop-top", str(crop_top)) metadata.add_text("crop-top", str(crop_top))
metadata.add_text("crop-left", str(crop_left)) metadata.add_text("crop-left", str(crop_left))
if args.use_original_file_name and init_images is not None: if filename is not None:
if type(init_images) is list: fln = filename
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
else:
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
elif args.sequential_file_name:
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
else: else:
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
else:
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
elif args.sequential_file_name:
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
else:
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if fln.endswith(".webp"):
image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy
else:
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
if not args.no_preview and not highres_1st and args.interactive: if not args.no_preview and not highres_1st and args.interactive:
try: try:
@@ -2562,6 +2571,7 @@ def main(args):
# repeat prompt # repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
filename = None
if pi == 0 or len(raw_prompts) > 1: if pi == 0 or len(raw_prompts) > 1:
# parse prompt: if prompt is not changed, skip parsing # parse prompt: if prompt is not changed, skip parsing
@@ -2783,6 +2793,12 @@ def main(args):
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue continue
m = re.match(r"f (.+)", parg, re.IGNORECASE)
if m: # filename
filename = m.group(1)
logger.info(f"filename: {filename}")
continue
except ValueError as ex: except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}") logger.error(f"{ex}")
@@ -2873,7 +2889,16 @@ def main(args):
b1 = BatchData( b1 = BatchData(
False, False,
BatchDataBase( BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt global_step,
prompt,
negative_prompt,
seed,
init_image,
mask_image,
clip_prompt,
guide_image,
raw_prompt,
filename,
), ),
BatchDataExt( BatchDataExt(
width, width,