mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'kohya-ss:main' into support-multi-gpu
This commit is contained in:
@@ -113,7 +113,7 @@ class BucketManager():
|
||||
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
||||
self.predefined_resos = resos.copy()
|
||||
self.predefined_resos_set = set(resos)
|
||||
self.predifined_aspect_ratios = np.array([w / h for w, h in resos])
|
||||
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
||||
|
||||
def add_if_new_reso(self, reso):
|
||||
if reso not in self.reso_to_id:
|
||||
@@ -135,7 +135,7 @@ class BucketManager():
|
||||
if reso in self.predefined_resos_set:
|
||||
pass
|
||||
else:
|
||||
ar_errors = self.predifined_aspect_ratios - aspect_ratio
|
||||
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
||||
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
||||
reso = self.predefined_resos[predefined_bucket_id]
|
||||
|
||||
@@ -223,6 +223,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
self.dropout_rate: float = 0
|
||||
self.dropout_every_n_epochs: int = None
|
||||
self.tag_dropout_rate: float = 0
|
||||
|
||||
# augmentation
|
||||
flip_p = 0.5 if flip_aug else 0.0
|
||||
if color_aug:
|
||||
@@ -247,6 +252,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
self.current_epoch = epoch
|
||||
|
||||
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
||||
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
||||
self.dropout_rate = dropout_rate
|
||||
self.dropout_every_n_epochs = dropout_every_n_epochs
|
||||
self.tag_dropout_rate = tag_dropout_rate
|
||||
|
||||
def set_tag_frequency(self, dir_name, captions):
|
||||
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
||||
self.tag_frequency[dir_name] = frequency_for_dir
|
||||
@@ -264,27 +278,52 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.replacements[str_from] = str_to
|
||||
|
||||
def process_caption(self, caption):
|
||||
if self.shuffle_caption:
|
||||
tokens = caption.strip().split(",")
|
||||
if self.shuffle_keep_tokens is None:
|
||||
random.shuffle(tokens)
|
||||
else:
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
||||
tokens = tokens[self.shuffle_keep_tokens:]
|
||||
random.shuffle(tokens)
|
||||
tokens = keep_tokens + tokens
|
||||
caption = ",".join(tokens).strip()
|
||||
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
||||
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
||||
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
||||
|
||||
for str_from, str_to in self.replacements.items():
|
||||
if str_from == "":
|
||||
# replace all
|
||||
if type(str_to) == list:
|
||||
caption = random.choice(str_to)
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
||||
def dropout_tags(tokens):
|
||||
if self.tag_dropout_rate <= 0:
|
||||
return tokens
|
||||
l = []
|
||||
for token in tokens:
|
||||
if random.random() >= self.tag_dropout_rate:
|
||||
l.append(token)
|
||||
return l
|
||||
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if self.shuffle_keep_tokens is None:
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
else:
|
||||
caption = str_to
|
||||
else:
|
||||
caption = caption.replace(str_from, str_to)
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
||||
tokens = tokens[self.shuffle_keep_tokens:]
|
||||
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
|
||||
tokens = keep_tokens + tokens
|
||||
caption = ", ".join(tokens)
|
||||
|
||||
# textual inversion対応
|
||||
for str_from, str_to in self.replacements.items():
|
||||
if str_from == "":
|
||||
# replace all
|
||||
if type(str_to) == list:
|
||||
caption = random.choice(str_to)
|
||||
else:
|
||||
caption = str_to
|
||||
else:
|
||||
caption = caption.replace(str_from, str_to)
|
||||
|
||||
return caption
|
||||
|
||||
@@ -393,17 +432,25 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List(BucketBatchIndex) = []
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
# bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# そのためバッチサイズを画像種類までに制限する
|
||||
# ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# TODO 正則化画像をepochまたがりで利用する仕組み
|
||||
num_of_image_types = len(set(bucket))
|
||||
bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
||||
|
||||
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
||||
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
||||
#
|
||||
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# # そのためバッチサイズを画像種類までに制限する
|
||||
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
||||
# num_of_image_types = len(set(bucket))
|
||||
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
# for batch_index in range(batch_count):
|
||||
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
# ↑ここまで
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
@@ -809,6 +856,7 @@ class FineTuningDataset(BaseDataset):
|
||||
self.num_train_images = len(metadata) * dataset_repeats
|
||||
self.num_reg_images = 0
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
||||
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
|
||||
@@ -907,6 +955,8 @@ class FineTuningDataset(BaseDataset):
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
|
||||
train_dataset.set_current_epoch(1)
|
||||
k = 0
|
||||
for i, example in enumerate(train_dataset):
|
||||
if example['latents'] is not None:
|
||||
@@ -1377,7 +1427,7 @@ def verify_training_args(args: argparse.Namespace):
|
||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
|
||||
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool):
|
||||
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
|
||||
# dataset common
|
||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--shuffle_caption", action="store_true",
|
||||
@@ -1408,6 +1458,16 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
||||
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
||||
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
||||
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
||||
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
||||
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth dataset
|
||||
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
||||
|
||||
Reference in New Issue
Block a user