implement token warmup

This commit is contained in:
u-haru
2023-03-23 07:37:14 +09:00
parent 432353185c
commit a9b26b73e0
6 changed files with 75 additions and 2 deletions

View File

@@ -197,6 +197,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -263,6 +266,7 @@ def train(args):
loss_total = 0 loss_total = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
train_dataset_group.set_current_step(step + 1)
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) latents = batch["latents"].to(accelerator.device)

View File

@@ -56,6 +56,8 @@ class BaseSubsetParams:
caption_dropout_rate: float = 0.0 caption_dropout_rate: float = 0.0
caption_dropout_every_n_epochs: int = 0 caption_dropout_every_n_epochs: int = 0
caption_tag_dropout_rate: float = 0.0 caption_tag_dropout_rate: float = 0.0
token_warmup_min: int = 1
token_warmup_step: Union[float,int] = 0
@dataclass @dataclass
class DreamBoothSubsetParams(BaseSubsetParams): class DreamBoothSubsetParams(BaseSubsetParams):
@@ -137,6 +139,8 @@ class ConfigSanitizer:
"random_crop": bool, "random_crop": bool,
"shuffle_caption": bool, "shuffle_caption": bool,
"keep_tokens": int, "keep_tokens": int,
"token_warmup_min": int,
"token_warmup_step": Union[float,int],
} }
# DO means DropOut # DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = { DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -406,6 +410,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug} flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range} face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop} random_crop: {subset.random_crop}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""), " ") """), " ")
if is_dreambooth: if is_dreambooth:

View File

@@ -277,6 +277,8 @@ class BaseSubset:
caption_dropout_rate: float, caption_dropout_rate: float,
caption_dropout_every_n_epochs: int, caption_dropout_every_n_epochs: int,
caption_tag_dropout_rate: float, caption_tag_dropout_rate: float,
token_warmup_min: int,
token_warmup_step: Union[float,int],
) -> None: ) -> None:
self.image_dir = image_dir self.image_dir = image_dir
self.num_repeats = num_repeats self.num_repeats = num_repeats
@@ -290,6 +292,9 @@ class BaseSubset:
self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
self.caption_tag_dropout_rate = caption_tag_dropout_rate self.caption_tag_dropout_rate = caption_tag_dropout_rate
self.token_warmup_min = token_warmup_min # step=0におけるタグの数
self.token_warmup_step = token_warmup_step # Nステップ目（N<1ならN*max_train_stepsステップ目）でタグの数が最大になる
self.img_count = 0 self.img_count = 0
@@ -310,6 +315,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -325,6 +332,8 @@ class DreamBoothSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) )
self.is_reg = is_reg self.is_reg = is_reg
@@ -352,6 +361,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) -> None: ) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -367,6 +378,8 @@ class FineTuningSubset(BaseSubset):
caption_dropout_rate, caption_dropout_rate,
caption_dropout_every_n_epochs, caption_dropout_every_n_epochs,
caption_tag_dropout_rate, caption_tag_dropout_rate,
token_warmup_min,
token_warmup_step,
) )
self.metadata_file = metadata_file self.metadata_file = metadata_file
@@ -405,6 +418,9 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.current_step: int = 0
self.max_train_steps: int = 0
# augmentation # augmentation
self.aug_helper = AugHelper() self.aug_helper = AugHelper()
@@ -424,6 +440,12 @@ class BaseDataset(torch.utils.data.Dataset):
self.current_epoch = epoch self.current_epoch = epoch
self.shuffle_buckets() self.shuffle_buckets()
def set_current_step(self, step):
self.current_step = step
def set_max_train_steps(self, max_train_steps):
self.max_train_steps = max_train_steps
def set_tag_frequency(self, dir_name, captions): def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {}) frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir self.tag_frequency[dir_name] = frequency_for_dir
@@ -453,7 +475,7 @@ class BaseDataset(torch.utils.data.Dataset):
if is_drop_out: if is_drop_out:
caption = "" caption = ""
else: else:
if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
def dropout_tags(tokens): def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0: if subset.caption_tag_dropout_rate <= 0:
@@ -474,8 +496,15 @@ class BaseDataset(torch.utils.data.Dataset):
random.shuffle(flex_tokens) random.shuffle(flex_tokens)
flex_tokens = dropout_tags(flex_tokens) flex_tokens = dropout_tags(flex_tokens)
tokens = fixed_tokens + flex_tokens
caption = ", ".join(fixed_tokens + flex_tokens) if subset.token_warmup_step < 1:
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min
tokens = tokens[:tokens_len]
caption = ", ".join(tokens)
# textual inversion対応 # textual inversion対応
for str_from, str_to in self.replacements.items(): for str_from, str_to in self.replacements.items():
@@ -1249,6 +1278,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets: for dataset in self.datasets:
dataset.set_current_epoch(epoch) dataset.set_current_epoch(epoch)
def set_current_step(self, step):
for dataset in self.datasets:
dataset.set_current_step(step)
def set_max_train_steps(self, max_train_steps):
for dataset in self.datasets:
dataset.set_max_train_steps(max_train_steps)
def disable_token_padding(self): def disable_token_padding(self):
for dataset in self.datasets: for dataset in self.datasets:
dataset.disable_token_padding() dataset.disable_token_padding()
@@ -2001,6 +2038,20 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
) )
parser.add_argument(
"--token_warmup_min",
type=int,
default=1,
help="start learning at N tags (token means comma separated strings) / タグ数をN個から増やしながら学習する",
)
parser.add_argument(
"--token_warmup_step",
type=float,
default=0,
help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / Nステップ（N<1ならN*max_train_stepsステップ）でタグ長が最大になる。デフォルトは0（最初から最大）",
)
if support_caption_dropout: if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない # Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに

View File

@@ -162,6 +162,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
if args.stop_text_encoder_training is None: if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
@@ -246,6 +249,7 @@ def train(args):
text_encoder.requires_grad_(False) text_encoder.requires_grad_(False)
with accelerator.accumulate(unet): with accelerator.accumulate(unet):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad(): with torch.no_grad():
# latentに変換 # latentに変換
if cache_latents: if cache_latents:

View File

@@ -200,6 +200,9 @@ def train(args):
if is_main_process: if is_main_process:
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -505,6 +508,7 @@ def train(args):
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(network): with accelerator.accumulate(network):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) latents = batch["latents"].to(accelerator.device)

View File

@@ -260,6 +260,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -338,6 +341,7 @@ def train(args):
loss_total = 0 loss_total = 0
for step, batch in enumerate(train_dataloader): for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder): with accelerator.accumulate(text_encoder):
train_dataset_group.set_current_step(step + 1)
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device) latents = batch["latents"].to(accelerator.device)