mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix: bring positional args back, add recursive to blip etc
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -72,7 +73,8 @@ def main(args):
|
|||||||
os.chdir('finetune')
|
os.chdir('finetune')
|
||||||
|
|
||||||
print(f"load images from {args.train_data_dir}")
|
print(f"load images from {args.train_data_dir}")
|
||||||
image_paths = train_util.glob_images(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
|
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
print(f"loading BLIP caption: {args.caption_weights}")
|
print(f"loading BLIP caption: {args.caption_weights}")
|
||||||
@@ -152,7 +154,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
||||||
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
|
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -65,7 +66,8 @@ def main(args):
|
|||||||
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
GenerationMixin._prepare_input_ids_for_generation = _prepare_input_ids_for_generation_patch
|
||||||
|
|
||||||
print(f"load images from {args.train_data_dir}")
|
print(f"load images from {args.train_data_dir}")
|
||||||
image_paths = train_util.glob_images(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
|
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
# できればcacheに依存せず明示的にダウンロードしたい
|
# できればcacheに依存せず明示的にダウンロードしたい
|
||||||
@@ -140,7 +142,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--remove_words", action="store_true",
|
parser.add_argument("--remove_words", action="store_true",
|
||||||
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
|
help="remove like `with the words xxx` from caption / `with the words xxx`のような部分をキャプションから削除する")
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
parser.add_argument("--recursive", action="store_true", help="search for images in subfolders recursively")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@@ -41,14 +43,22 @@ def get_latents(vae, images, weight_dtype):
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
|
|
||||||
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip):
|
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
|
||||||
if is_full_path:
|
if is_full_path:
|
||||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||||
|
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
|
||||||
else:
|
else:
|
||||||
base_name = image_key
|
base_name = image_key
|
||||||
|
relative_path = ""
|
||||||
|
|
||||||
if flip:
|
if flip:
|
||||||
base_name += '_flip'
|
base_name += '_flip'
|
||||||
return os.path.join(data_dir, base_name)
|
|
||||||
|
if recursive and relative_path:
|
||||||
|
return os.path.join(data_dir, relative_path, base_name)
|
||||||
|
else:
|
||||||
|
return os.path.join(data_dir, base_name)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@@ -56,7 +66,8 @@ def main(args):
|
|||||||
if args.bucket_reso_steps % 8 > 0:
|
if args.bucket_reso_steps % 8 > 0:
|
||||||
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
print(f"resolution of buckets in training time is a multiple of 8 / 学習時の各bucketの解像度は8単位になります")
|
||||||
|
|
||||||
image_paths = train_util.glob_images(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
|
image_paths: List[str] = [str(p) for p in train_util.glob_images_pathlib(train_data_dir_path, args.recursive)]
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if os.path.exists(args.in_json):
|
if os.path.exists(args.in_json):
|
||||||
@@ -99,7 +110,7 @@ def main(args):
|
|||||||
f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
||||||
|
|
||||||
for (image_key, _), latent in zip(bucket, latents):
|
for (image_key, _), latent in zip(bucket, latents):
|
||||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False)
|
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
|
||||||
np.savez(npz_file_name, latent)
|
np.savez(npz_file_name, latent)
|
||||||
|
|
||||||
# flip
|
# flip
|
||||||
@@ -107,12 +118,12 @@ def main(args):
|
|||||||
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない
|
||||||
|
|
||||||
for (image_key, _), latent in zip(bucket, latents):
|
for (image_key, _), latent in zip(bucket, latents):
|
||||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True)
|
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive)
|
||||||
np.savez(npz_file_name, latent)
|
np.savez(npz_file_name, latent)
|
||||||
else:
|
else:
|
||||||
# remove existing flipped npz
|
# remove existing flipped npz
|
||||||
for image_key, _ in bucket:
|
for image_key, _ in bucket:
|
||||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz"
|
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
|
||||||
if os.path.isfile(npz_file_name):
|
if os.path.isfile(npz_file_name):
|
||||||
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
||||||
os.remove(npz_file_name)
|
os.remove(npz_file_name)
|
||||||
@@ -169,9 +180,9 @@ def main(args):
|
|||||||
|
|
||||||
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
# 既に存在するファイルがあればshapeを確認して同じならskipする
|
||||||
if args.skip_existing:
|
if args.skip_existing:
|
||||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False) + ".npz"]
|
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
|
||||||
if args.flip_aug:
|
if args.flip_aug:
|
||||||
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True) + ".npz")
|
npz_files.append(get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz")
|
||||||
|
|
||||||
found = True
|
found = True
|
||||||
for npz_file in npz_files:
|
for npz_file in npz_files:
|
||||||
@@ -256,6 +267,8 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する")
|
||||||
parser.add_argument("--skip_existing", action="store_true",
|
parser.add_argument("--skip_existing", action="store_true",
|
||||||
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)")
|
||||||
|
parser.add_argument("--recursive", action="store_true",
|
||||||
|
help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import numpy as np
|
|||||||
from tensorflow.keras.models import load_model
|
from tensorflow.keras.models import load_model
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
import torch
|
import torch
|
||||||
import pathlib
|
from pathlib import Path
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
|
||||||
@@ -103,8 +103,8 @@ def main(args):
|
|||||||
|
|
||||||
# 画像を読み込む
|
# 画像を読み込む
|
||||||
|
|
||||||
train_data_dir = pathlib.Path(args.train_data_dir)
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
image_paths = train_util.glob_images_pathlib(train_data_dir, args.recursive)
|
image_paths = train_util.glob_images_pathlib(train_data_dir_path, args.recursive)
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
tag_freq = {}
|
tag_freq = {}
|
||||||
@@ -205,7 +205,7 @@ def main(args):
|
|||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
|
parser.add_argument("--repo_id", type=str, default=DEFAULT_WD14_TAGGER_REPO,
|
||||||
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
|
help="repo id for wd14 tagger on Hugging Face / Hugging Faceのwd14 taggerのリポジトリID")
|
||||||
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
|
parser.add_argument("--model_dir", type=str, default="wd14_tagger_model",
|
||||||
|
|||||||
Reference in New Issue
Block a user