fix: restore positional train_data_dir argument; add --recursive option to BLIP/GIT captioning and latent caching scripts

This commit is contained in:
Linaqruf
2023-04-12 05:41:28 +07:00
parent bf8088e225
commit c316c63dff
4 changed files with 35 additions and 16 deletions

View File

@@ -4,6 +4,7 @@ import os
import json import json
import random import random
from pathlib import Path
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
@@ -72,7 +73,8 @@ def main(args):
os.chdir('finetune') os.chdir('finetune')
print(f"load images from {args.train_data_dir}") print(f"load images from {args.train_data_dir}")
image_paths = train_util.glob_images(args.train_data_dir) train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
print(f"loading BLIP caption: {args.caption_weights}") print(f"loading BLIP caption: {args.caption_weights}")
@@ -152,7 +154,8 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長") parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed') parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
return parser return parser

View File

@@ -2,6 +2,7 @@ import argparse
import os import os
import re import re
from pathlib import Path
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -65,7 +66,8 @@ def main(args):
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
print(f"load images from {args.train_data_dir}") print(f"load images from {args.train_data_dir}")
image_paths = train_util.glob_images(args.train_data_dir) train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
# できればcacheに依存せず明示的にダウンロードしたい # できればcacheに依存せず明示的にダウンロードしたい
@@ -140,7 +142,8 @@ def setup_parser() -> argparse.ArgumentParser:
parser.add_argument("--remove_words", action="store_true", parser.add_argument("--remove_words", action="store_true",
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する") help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
parser.add_argument("--debug", action="store_true", help="debug mode") parser.add_argument("--debug", action="store_true", help="debug mode")
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
return parser return parser

View File

@@ -2,6 +2,8 @@ import argparse
import os import os
import json import json
from pathlib import Path
from typing import List
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@@ -41,14 +43,22 @@ def get_latents(vae, images, weight_dtype):
return latents return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip): def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
if is_full_path: if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0] base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
else: else:
base_name = image_key base_name = image_key
relative_path = ""
if flip: if flip:
base_name += '_flip' base_name += '_flip'
return os.path.join(data_dir, base_name)
if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name)
else:
return os.path.join(data_dir, base_name)
def main(args): def main(args):
@@ -56,7 +66,8 @@ def main(args):
if args.bucket_reso_steps % 8 > 0: if args.bucket_reso_steps % 8 > 0:
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります") print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
image_paths = train_util.glob_images(args.train_data_dir) train_data_dir_path = Path(args.train_data_dir)
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
if os.path.exists(args.in_json): if os.path.exists(args.in_json):
@@ -99,7 +110,7 @@ def main(args):
f"latent shape {latents.shape}, {bucket[0][1].shape}" f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _), latent in zip(bucket, latents): for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
np.savez(npz_file_name, latent) np.savez(npz_file_name, latent)
# flip # flip
@@ -107,12 +118,12 @@ def main(args):
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
for (image_key, _), latent in zip(bucket, latents): for (image_key, _), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive)
np.savez(npz_file_name, latent) np.savez(npz_file_name, latent)
else: else:
# remove existing flipped npz # remove existing flipped npz
for image_key, _ in bucket: for image_key, _ in bucket:
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz" npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
if os.path.isfile(npz_file_name): if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}") print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name) os.remove(npz_file_name)
@@ -169,9 +180,9 @@ def main(args):
# 既に存在するファイルがあればshapeを確認して同じならskipする # 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing: if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"] npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
if args.flip_aug: if args.flip_aug:
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz") npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz")
found = True found = True
for npz_file in npz_files: for npz_file in npz_files:
@@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser:
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する") help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
parser.add_argument("--skip_existing", action="store_true", parser.add_argument("--skip_existing", action="store_true",
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ") help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ")
parser.add_argument("--recursive", action="store_true",
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
return parser return parser

View File

@@ -10,7 +10,7 @@ import numpy as np
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import torch import torch
import pathlib from pathlib import Path
import library.train_util as train_util import library.train_util as train_util
@@ -103,8 +103,8 @@ def main(args):
# 画像を読み込む # 画像を読み込む
train_data_dir = pathlib.Path(args.train_data_dir) train_data_dir_path = Path(args.train_data_dir)
image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive) image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
print(f"found {len(image_paths)} images.") print(f"found {len(image_paths)} images.")
tag_freq = {} tag_freq = {}
@@ -205,7 +205,7 @@ def main(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO, parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID") help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model", parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",