mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
conditional caption dropout (in progress)
This commit is contained in:
@@ -113,7 +113,7 @@ class BucketManager():
|
||||
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
||||
self.predefined_resos = resos.copy()
|
||||
self.predefined_resos_set = set(resos)
|
||||
self.predifined_aspect_ratios = np.array([w / h for w, h in resos])
|
||||
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
||||
|
||||
def add_if_new_reso(self, reso):
|
||||
if reso not in self.reso_to_id:
|
||||
@@ -135,7 +135,7 @@ class BucketManager():
|
||||
if reso in self.predefined_resos_set:
|
||||
pass
|
||||
else:
|
||||
ar_errors = self.predifined_aspect_ratios - aspect_ratio
|
||||
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
||||
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
||||
reso = self.predefined_resos[predefined_bucket_id]
|
||||
|
||||
@@ -223,9 +223,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
self.epoch_current:int = int(0)
|
||||
self.dropout_rate:float = 0
|
||||
self.dropout_every_n_epochs:int = 0
|
||||
# TODO 外から渡したほうが安心だが自動で計算したほうが呼ぶ側に余分なコードがいらないのでよさそう
|
||||
self.epoch_current: int = int(0)
|
||||
self.dropout_rate: float = 0
|
||||
self.dropout_every_n_epochs: int = None
|
||||
|
||||
# augmentation
|
||||
flip_p = 0.5 if flip_aug else 0.0
|
||||
@@ -251,6 +252,12 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs):
|
||||
# 将来的にタグのドロップアウトも対応したいのでメソッドを生やしておく
|
||||
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
||||
self.dropout_rate = dropout_rate
|
||||
self.dropout_every_n_epochs = dropout_every_n_epochs
|
||||
|
||||
def set_tag_frequency(self, dir_name, captions):
|
||||
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
||||
self.tag_frequency[dir_name] = frequency_for_dir
|
||||
@@ -604,9 +611,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# dropoutの決定
|
||||
is_drop_out = False
|
||||
if self.dropout_rate > 0 and self.dropout_rate < random.random() :
|
||||
if self.dropout_rate > 0 and random.random() < self.dropout_rate:
|
||||
is_drop_out = True
|
||||
if self.dropout_every_n_epochs > 0 and self.epoch_current % self.dropout_every_n_epochs == 0 :
|
||||
if self.dropout_every_n_epochs and self.epoch_current % self.dropout_every_n_epochs == 0:
|
||||
is_drop_out = True
|
||||
|
||||
if is_drop_out:
|
||||
@@ -1391,7 +1398,7 @@ def verify_training_args(args: argparse.Namespace):
|
||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
|
||||
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool):
|
||||
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
|
||||
# dataset common
|
||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--shuffle_caption", action="store_true",
|
||||
@@ -1421,10 +1428,14 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
|
||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||
parser.add_argument("--dropout_rate", type=float, default=0,
|
||||
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
||||
parser.add_argument("--dropout_every_n_epochs", type=int, default=0,
|
||||
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
||||
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
||||
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
||||
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
||||
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth dataset
|
||||
|
||||
Reference in New Issue
Block a user