mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix: bring positional args back, add recursive to blip etc
This commit is contained in:
@@ -2,6 +2,8 @@ import argparse
|
||||
import os
|
||||
import json
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -41,14 +43,22 @@ def get_latents(vae, images, weight_dtype):
|
||||
return latents
|
||||
|
||||
|
||||
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
||||
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
|
||||
else:
|
||||
base_name = image_key
|
||||
relative_path = ""
|
||||
|
||||
if flip:
|
||||
base_name += '_flip'
|
||||
return os.path.join(data_dir, base_name)
|
||||
|
||||
if recursive and relative_path:
|
||||
return os.path.join(data_dir, relative_path, base_name)
|
||||
else:
|
||||
return os.path.join(data_dir, base_name)
|
||||
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -56,7 +66,8 @@ def main(args):
|
||||
if args.bucket_reso_steps % 8 > 0:
|
||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||
|
||||
image_paths = train_util.glob_images(args.train_data_dir)
|
||||
train_data_dir_path = Path(args.train_data_dir)
|
||||
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
|
||||
print(f"found {len(image_paths)} images.")
|
||||
|
||||
if os.path.exists(args.in_json):
|
||||
@@ -99,7 +110,7 @@ def main(args):
|
||||
f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
||||
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
|
||||
np.savez(npz_file_name, latent)
|
||||
|
||||
# flip
|
||||
@@ -107,12 +118,12 @@ def main(args):
|
||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||
|
||||
for (image_key, _), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive)
|
||||
np.savez(npz_file_name, latent)
|
||||
else:
|
||||
# remove existing flipped npz
|
||||
for image_key, _ in bucket:
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
|
||||
if os.path.isfile(npz_file_name):
|
||||
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
||||
os.remove(npz_file_name)
|
||||
@@ -169,9 +180,9 @@ def main(args):
|
||||
|
||||
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
||||
if args.skip_existing:
|
||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
|
||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
|
||||
if args.flip_aug:
|
||||
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
|
||||
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz")
|
||||
|
||||
found = True
|
||||
for npz_file in npz_files:
|
||||
@@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
||||
parser.add_argument("--skip_existing", action="store_true",
|
||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
||||
parser.add_argument("--recursive", action="store_true",
|
||||
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user