mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into main
This commit is contained in:
@@ -2,25 +2,34 @@
|
|||||||
# (c) 2022 Kohya S. @kohya_ss
|
# (c) 2022 Kohya S. @kohya_ss
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import glob
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
|
image_paths = None
|
||||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
train_data_dir_path = Path(args.train_data_dir)
|
||||||
|
if args.recursive:
|
||||||
|
image_paths = list(train_data_dir_path.rglob('*.jpg')) + \
|
||||||
|
list(train_data_dir_path.rglob('*.jpeg')) + \
|
||||||
|
list(train_data_dir_path.rglob('*.png')) + \
|
||||||
|
list(train_data_dir_path.rglob('*.webp'))
|
||||||
|
else:
|
||||||
|
image_paths = list(train_data_dir_path.glob('*.jpg')) + \
|
||||||
|
list(train_data_dir_path.glob('*.jpeg')) + \
|
||||||
|
list(train_data_dir_path.glob('*.png')) + \
|
||||||
|
list(train_data_dir_path.glob('*.webp'))
|
||||||
|
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
if args.in_json is None and os.path.isfile(args.out_json):
|
if args.in_json is None and Path(args.out_json).is_file():
|
||||||
args.in_json = args.out_json
|
args.in_json = args.out_json
|
||||||
|
|
||||||
if args.in_json is not None:
|
if args.in_json is not None:
|
||||||
print(f"loading existing metadata: {args.in_json}")
|
print(f"loading existing metadata: {args.in_json}")
|
||||||
with open(args.in_json, "rt", encoding='utf-8') as f:
|
metadata = json.loads(Path(args.in_json).read_text(encoding='utf-8'))
|
||||||
metadata = json.load(f)
|
|
||||||
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
|
||||||
else:
|
else:
|
||||||
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
print("new metadata will be created / 新しいメタデータファイルが作成されます")
|
||||||
@@ -28,22 +37,21 @@ def main(args):
|
|||||||
|
|
||||||
print("merge tags to metadata json.")
|
print("merge tags to metadata json.")
|
||||||
for image_path in tqdm(image_paths):
|
for image_path in tqdm(image_paths):
|
||||||
tags_path = os.path.splitext(image_path)[0] + '.txt'
|
tags_path = image_path.with_suffix('.txt')
|
||||||
with open(tags_path, "rt", encoding='utf-8') as f:
|
tags = tags_path.read_text(encoding='utf-8').strip()
|
||||||
tags = f.readlines()[0].strip()
|
|
||||||
|
|
||||||
image_key = image_path if args.full_path else os.path.splitext(os.path.basename(image_path))[0]
|
image_key = image_path if args.full_path else image_path.stem
|
||||||
if image_key not in metadata:
|
if str(image_key) not in metadata:
|
||||||
metadata[image_key] = {}
|
metadata[str(image_key)] = {}
|
||||||
|
|
||||||
metadata[image_key]['tags'] = tags
|
metadata[str(image_key)]['tags'] = tags
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(image_key, tags)
|
print(image_key, tags)
|
||||||
|
|
||||||
# metadataを書き出して終わり
|
# metadataを書き出して終わり
|
||||||
print(f"writing metadata: {args.out_json}")
|
print(f"writing metadata: {args.out_json}")
|
||||||
with open(args.out_json, "wt", encoding='utf-8') as f:
|
Path(args.out_json).write_text(json.dumps(metadata, indent=2), encoding='utf-8')
|
||||||
json.dump(metadata, f, indent=2)
|
|
||||||
print("done!")
|
print("done!")
|
||||||
|
|
||||||
|
|
||||||
@@ -54,6 +62,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
parser.add_argument("--in_json", type=str, help="metadata file to input (if omitted and out_json exists, existing out_json is read) / 読み込むメタデータファイル(省略時、out_jsonが存在すればそれを読み込む)")
|
||||||
parser.add_argument("--full_path", action="store_true",
|
parser.add_argument("--full_path", action="store_true",
|
||||||
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)")
|
||||||
|
parser.add_argument("--recursive", action="store_true", help="recursively look for training tags in all child folders of train_data_dir / train_data_dirのすべての子フォルダにある学習タグを再帰的に探す")
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
|
parser.add_argument("--debug", action="store_true", help="debug mode, print tags")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
@@ -36,8 +36,11 @@ def main(args):
|
|||||||
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
|
args.model_dir, SUB_DIR), force_download=True, force_filename=file)
|
||||||
|
|
||||||
# 画像を読み込む
|
# 画像を読み込む
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
|
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
glob.glob(os.path.join(args.train_data_dir, "*.jpeg")) + \
|
||||||
|
glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
|
||||||
|
glob.glob(os.path.join(args.train_data_dir, "*.webp")) + \
|
||||||
|
glob.glob(os.path.join(args.train_data_dir, "*.bmp"))
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
print("loading model and labels")
|
print("loading model and labels")
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.enable_bucket = False
|
self.enable_bucket = False
|
||||||
self.min_bucket_reso = None
|
self.min_bucket_reso = None
|
||||||
self.max_bucket_reso = None
|
self.max_bucket_reso = None
|
||||||
|
self.tag_frequency = {}
|
||||||
self.bucket_info = None
|
self.bucket_info = None
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
@@ -545,6 +546,15 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
cap_for_img = read_caption(img_path)
|
cap_for_img = read_caption(img_path)
|
||||||
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
||||||
|
|
||||||
|
frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {})
|
||||||
|
self.tag_frequency[os.path.basename(dir)] = frequency_for_dir
|
||||||
|
for caption in captions:
|
||||||
|
for tag in caption.split(","):
|
||||||
|
if tag and not tag.isspace():
|
||||||
|
tag = tag.lower()
|
||||||
|
frequency = frequency_for_dir.get(tag, 0)
|
||||||
|
frequency_for_dir[tag] = frequency + 1
|
||||||
|
|
||||||
return n_repeats, img_paths, captions
|
return n_repeats, img_paths, captions
|
||||||
|
|
||||||
print("prepare train images.")
|
print("prepare train images.")
|
||||||
|
|||||||
@@ -335,6 +335,7 @@ def train(args):
|
|||||||
"ss_keep_tokens": args.keep_tokens,
|
"ss_keep_tokens": args.keep_tokens,
|
||||||
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
||||||
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||||
|
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
||||||
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
||||||
"ss_training_comment": args.training_comment # will not be updated after training
|
"ss_training_comment": args.training_comment # will not be updated after training
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user