implement token warmup

This commit is contained in:
u-haru
2023-03-23 07:37:14 +09:00
parent 432353185c
commit a9b26b73e0
6 changed files with 75 additions and 2 deletions

View File

@@ -277,6 +277,8 @@ class BaseSubset:
caption_dropout_rate: float,
caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float,
token_warmup_min: int,
token_warmup_step: Union[float,int],
) -> None:
self.image_dir = image_dir
self.num_repeats = num_repeats
@@ -290,6 +292,9 @@ class BaseSubset:
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # NN<1ならN*max_train_stepsステップ目でタグの数が最大になる
self.img_count = 0
@@ -310,6 +315,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -325,6 +332,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.is_reg = is_reg
@@ -352,6 +361,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -367,6 +378,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate,
caption_dropout_every_n_epochs,
caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
)
self.metadata_file = metadata_file
@@ -405,6 +418,9 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.current_step: int = 0
self.max_train_steps: int = 0
# augmentation
self.aug_helper = AugHelper()
@@ -424,6 +440,12 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch = epoch
self.shuffle_buckets()
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir
@@ -453,7 +475,7 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out:
caption = ""
else:
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
@@ -474,8 +496,15 @@ class BaseDataset(torch.utils.data.Dataset):
random.shuffle(flex_tokens)
flex_tokens = dropout_tags(flex_tokens)
tokens = fixed_tokens + flex_tokens
caption = ", ".join(fixed_tokens + flex_tokens)
if subset.token_warmup_step < 1:
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
tokens = tokens[:tokens_len]
caption = ", ".join(tokens)
# textual inversion対応
for str_from, str_to in self.replacements.items():
@@ -1249,6 +1278,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self):
for dataset in self.datasets:
dataset.disable_token_padding()
@@ -2001,6 +2038,20 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)
parser.add_argument(
"--token_warmup_min",
type=int,
default=1,
help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_steps",
type=float,
default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / NN<1ならN*max_train_stepsステップでタグ長が最大になる。デフォルトは0最初から最大",
)
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに