mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
merge image_dir for metadata editor
This commit is contained in:
@@ -68,7 +68,8 @@ def train(args):
|
||||
user_config = config_util.load_user_config(args.config_file)
|
||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
print("ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
||||
print(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(', '.join(ignored)))
|
||||
else:
|
||||
if use_dreambooth_method:
|
||||
print("Use DreamBooth method.")
|
||||
@@ -99,7 +100,8 @@ def train(args):
|
||||
return
|
||||
|
||||
if cache_latents:
|
||||
assert train_dataset_group.is_latent_cacheable(), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
assert train_dataset_group.is_latent_cacheable(
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
# acceleratorを準備する
|
||||
print("prepare accelerator")
|
||||
@@ -255,10 +257,11 @@ def train(args):
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
||||
#print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
# TODO refactor metadata creation and move to util
|
||||
metadata = {
|
||||
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
|
||||
"ss_training_started_at": training_started_at, # unix timestamp
|
||||
@@ -304,6 +307,7 @@ def train(args):
|
||||
# or should also pack nested collections as json?
|
||||
datasets_metadata = []
|
||||
tag_frequency = {} # merge tag frequency for metadata editor
|
||||
dataset_dirs_info = {} # merge subset dirs for metadata editor
|
||||
|
||||
for dataset in train_dataset_group.datasets:
|
||||
is_dreambooth_dataset = isinstance(dataset, DreamBoothDataset)
|
||||
@@ -331,17 +335,41 @@ def train(args):
|
||||
"shuffle_caption": bool(subset.shuffle_caption),
|
||||
"keep_tokens": subset.keep_tokens,
|
||||
}
|
||||
|
||||
image_dir_or_metadata_file = None
|
||||
if subset.image_dir:
|
||||
subset_metadata["image_dir"] = os.path.basename(subset.image_dir)
|
||||
image_dir = os.path.basename(subset.image_dir)
|
||||
subset_metadata["image_dir"] = image_dir
|
||||
image_dir_or_metadata_file = image_dir
|
||||
|
||||
if is_dreambooth_dataset:
|
||||
subset_metadata["class_tokens"] = subset.class_tokens
|
||||
subset_metadata["is_reg"] = subset.is_reg
|
||||
if subset.is_reg:
|
||||
image_dir_or_metadata_file = None # not merging reg dataset
|
||||
else:
|
||||
subset_metadata["metadata_file"] = os.path.basename(subset.metadata_file)
|
||||
metadata_file = os.path.basename(subset.metadata_file)
|
||||
subset_metadata["metadata_file"] = metadata_file
|
||||
image_dir_or_metadata_file = metadata_file # may overwrite
|
||||
|
||||
subsets_metadata.append(subset_metadata)
|
||||
|
||||
# merge dataset dir: not reg subset only
|
||||
# TODO update additional-network extension to show detailed dataset config from metadata
|
||||
if image_dir_or_metadata_file is not None:
|
||||
# datasets may have a certain dir multiple times
|
||||
v = image_dir_or_metadata_file
|
||||
i = 2
|
||||
while v in dataset_dirs_info:
|
||||
v = image_dir_or_metadata_file + f" ({i})"
|
||||
i += 1
|
||||
image_dir_or_metadata_file = v
|
||||
|
||||
dataset_dirs_info[image_dir_or_metadata_file] = {
|
||||
"n_repeats": subset.num_repeats,
|
||||
"img_count": subset.img_count
|
||||
}
|
||||
|
||||
dataset_metadata["subsets"] = subsets_metadata
|
||||
datasets_metadata.append(dataset_metadata)
|
||||
|
||||
@@ -356,9 +384,11 @@ def train(args):
|
||||
|
||||
metadata["ss_datasets"] = json.dumps(datasets_metadata)
|
||||
metadata["ss_tag_frequency"] = json.dumps(tag_frequency)
|
||||
metadata["ss_dataset_dirs"] = json.dumps(dataset_dirs_info)
|
||||
else:
|
||||
# conserving backward compatibility when using train_dataset_dir and reg_dataset_dir
|
||||
assert len(train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
||||
assert len(
|
||||
train_dataset_group.datasets) == 1, f"There should be a single dataset but {len(train_dataset_group.datasets)} found. This seems to be a bug. / データセットは1個だけ存在するはずですが、実際には{len(train_dataset_group.datasets)}個でした。プログラムのバグかもしれません。"
|
||||
|
||||
dataset = train_dataset_group.datasets[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user