fix: bring positional args back, add recursive to blip etc

This commit is contained in:
Linaqruf
2023-04-12 05:41:28 +07:00
parent bf8088e225
commit c316c63dff
4 changed files with 35 additions and 16 deletions

View File

@@ -2,6 +2,8 @@ import argparse
import os
import json
from pathlib import Path
from typing import List
from tqdm import tqdm
import numpy as np
from PIL import Image
@@ -41,14 +43,22 @@ def get_latents(vae, images, weight_dtype):
return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
else:
base_name = image_key
relative_path = ""
if flip:
base_name += '_flip'
return os.path.join(data_dir, base_name)
if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name)
else:
return os.path.join(data_dir, base_name)
def main(args):
@@ -56,7 +66,8 @@ def main(args):
if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
image_paths = train_util.glob_images(args.train_data_dir)
train_data_dir_path = Path(args.train_data_dir)
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json):
@@ -99,7 +110,7 @@ def main(args):
f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
np.savez(npz_file_name, latent)
# flip
@@ -107,12 +118,12 @@ def main(args):
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive)
np.savez(npz_file_name, latent)
else:
# remove existing flipped npz
for image_key, _ in bucket:
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name)
@@ -169,9 +180,9 @@ def main(args):
# 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
if args.flip_aug:
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz")
found = True
for npz_file in npz_files:
@@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser:
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
return parser