add prompt option '--f' for filename

This commit is contained in:
Kohya S
2024-05-15 20:21:49 +09:00
parent 589c2aa025
commit 153764a687
2 changed files with 43 additions and 15 deletions

View File

@@ -179,6 +179,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821!
- Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
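- Added a prompt option `--f` to `gen_imgs.py` to specify the file name of the saved image, e.g. `--f my_image.png`. The name is used as given (include the extension) and takes precedence over the automatically generated file name; a name ending in `.webp` is saved with `quality=100`.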
- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。
- optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。
- `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。
@@ -219,6 +221,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します
- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247)
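- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました（例: `--f my_image.png`）。ファイル名は指定されたとおりに使用されるため拡張子も含めて指定してください。自動生成されるファイル名より優先され、`.webp` で終わる場合は `quality=100` で保存されます。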
### Apr 7, 2024 / 2024-04-07: v0.8.7

View File

@@ -1435,6 +1435,7 @@ class BatchDataBase(NamedTuple):
clip_prompt: str
guide_image: Any
raw_prompt: str
file_name: Optional[str]
class BatchDataExt(NamedTuple):
@@ -2316,7 +2317,7 @@ def main(args):
# このバッチの情報を取り出す
(
return_latents,
(step_first, _, _, _, init_image, mask_image, _, guide_image, _),
(step_first, _, _, _, init_image, mask_image, _, guide_image, _, _),
(
width,
height,
@@ -2339,6 +2340,7 @@ def main(args):
prompts = []
negative_prompts = []
raw_prompts = []
filenames = []
start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
noises = [
torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype)
@@ -2371,7 +2373,7 @@ def main(args):
all_guide_images_are_same = True
for i, (
_,
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt),
(_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename),
_,
) in enumerate(batch):
prompts.append(prompt)
@@ -2379,6 +2381,7 @@ def main(args):
seeds.append(seed)
clip_prompts.append(clip_prompt)
raw_prompts.append(raw_prompt)
filenames.append(filename)
if init_image is not None:
init_images.append(init_image)
@@ -2478,8 +2481,8 @@ def main(args):
# save image
highres_prefix = ("0" if highres_1st else "1") if highres_fix else ""
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts)
for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate(
zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames)
):
if highres_fix:
seed -= 1 # record original seed
@@ -2505,17 +2508,23 @@ def main(args):
metadata.add_text("crop-top", str(crop_top))
metadata.add_text("crop-left", str(crop_left))
if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
else:
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
elif args.sequential_file_name:
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
if filename is not None:
fln = filename
else:
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
if args.use_original_file_name and init_images is not None:
if type(init_images) is list:
fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png"
else:
fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png"
elif args.sequential_file_name:
fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png"
else:
fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
if fln.endswith(".webp"):
image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy
else:
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
if not args.no_preview and not highres_1st and args.interactive:
try:
@@ -2562,6 +2571,7 @@ def main(args):
# repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0]
filename = None
if pi == 0 or len(raw_prompts) > 1:
# parse prompt: if prompt is not changed, skip parsing
@@ -2783,6 +2793,12 @@ def main(args):
logger.info(f"gradual latent unsharp params: {gl_unsharp_params}")
continue
m = re.match(r"f (.+)", parg, re.IGNORECASE)
if m: # filename
filename = m.group(1)
logger.info(f"filename: {filename}")
continue
except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(f"{ex}")
@@ -2873,7 +2889,16 @@ def main(args):
b1 = BatchData(
False,
BatchDataBase(
global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt
global_step,
prompt,
negative_prompt,
seed,
init_image,
mask_image,
clip_prompt,
guide_image,
raw_prompt,
filename,
),
BatchDataExt(
width,