mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add tag frequency metadata
This commit is contained in:
@@ -85,6 +85,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.enable_bucket = False
|
self.enable_bucket = False
|
||||||
self.min_bucket_reso = None
|
self.min_bucket_reso = None
|
||||||
self.max_bucket_reso = None
|
self.max_bucket_reso = None
|
||||||
|
self.tag_frequency = {}
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
|
|
||||||
@@ -520,6 +521,15 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
cap_for_img = read_caption(img_path)
|
cap_for_img = read_caption(img_path)
|
||||||
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
|
||||||
|
|
||||||
|
frequency_for_dir = self.tag_frequency.get(os.path.basename(dir), {})
|
||||||
|
self.tag_frequency[os.path.basename(dir)] = frequency_for_dir
|
||||||
|
for caption in captions:
|
||||||
|
for tag in caption.split(","):
|
||||||
|
if tag and not tag.isspace():
|
||||||
|
tag = tag.lower()
|
||||||
|
frequency = frequency_for_dir.get(tag, 0)
|
||||||
|
frequency_for_dir[tag] = frequency + 1
|
||||||
|
|
||||||
return n_repeats, img_paths, captions
|
return n_repeats, img_paths, captions
|
||||||
|
|
||||||
print("prepare train images.")
|
print("prepare train images.")
|
||||||
|
|||||||
@@ -264,6 +264,7 @@ def train(args):
|
|||||||
"ss_keep_tokens": args.keep_tokens,
|
"ss_keep_tokens": args.keep_tokens,
|
||||||
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
||||||
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||||
|
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
||||||
"ss_training_comment": args.training_comment # will not be updated after training
|
"ss_training_comment": args.training_comment # will not be updated after training
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user