diff --git a/fine_tune.py b/fine_tune.py index d927bd73..473a13ec 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -197,6 +197,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -263,6 +266,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/library/config_util.py b/library/config_util.py index e62bfb89..98d89b7e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -56,6 +56,8 @@ class BaseSubsetParams: caption_dropout_rate: float = 0.0 caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: Union[float,int] = 0 @dataclass class DreamBoothSubsetParams(BaseSubsetParams): @@ -137,6 +139,8 @@ class ConfigSanitizer: "random_crop": bool, "shuffle_caption": bool, "keep_tokens": int, + "token_warmup_min": int, + "token_warmup_step": Union[float,int], } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -406,6 +410,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, """), " ") if is_dreambooth: diff --git a/library/train_util.py b/library/train_util.py index 7d311827..52b51314 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -277,6 +277,8 @@ class BaseSubset: caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float, + token_warmup_min: int, + token_warmup_step: Union[float,int], ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -290,6 +292,9 @@ class BaseSubset: self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_tag_dropout_rate = caption_tag_dropout_rate + self.token_warmup_min = token_warmup_min # step=0におけるタグの数 + self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.img_count = 0 @@ -310,6 +315,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -325,6 +332,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) self.is_reg = is_reg @@ -352,6 +361,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -367,6 +378,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) self.metadata_file = metadata_file @@ -405,6 +418,9 @@ class BaseDataset(torch.utils.data.Dataset): self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ + self.current_step: int = 0 + self.max_train_steps: int = 0 + # augmentation self.aug_helper = AugHelper() @@ -424,6 +440,12 @@ class BaseDataset(torch.utils.data.Dataset): self.current_epoch = epoch self.shuffle_buckets() + def set_current_step(self, step): + self.current_step = step + + def set_max_train_steps(self, max_train_steps): + self.max_train_steps = max_train_steps + def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) self.tag_frequency[dir_name] = frequency_for_dir @@ -453,7 +475,7 @@ class BaseDataset(torch.utils.data.Dataset): if is_drop_out: caption = "" else: - if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: @@ -474,8 +496,15 @@ class BaseDataset(torch.utils.data.Dataset): random.shuffle(flex_tokens) flex_tokens = dropout_tags(flex_tokens) + tokens = fixed_tokens + flex_tokens - caption = ", ".join(fixed_tokens + flex_tokens) + if subset.token_warmup_step < 1: + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min + tokens = tokens[:tokens_len] + + caption = ", ".join(tokens) # textual inversion対応 for str_from, str_to in self.replacements.items(): @@ -1249,6 +1278,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.set_current_epoch(epoch) + def set_current_step(self, step): + for dataset in self.datasets: + dataset.set_current_step(step) + + def set_max_train_steps(self, max_train_steps): + for dataset in self.datasets: + dataset.set_max_train_steps(max_train_steps) + def disable_token_padding(self): for dataset in self.datasets: dataset.disable_token_padding() @@ -2001,6 +2038,20 @@ def add_dataset_arguments( "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" ) + parser.add_argument( + "--token_warmup_min", + type=int, + default=1, + help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する", + ) + + parser.add_argument( + "--token_warmup_steps", + type=float, + default=0, + help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", + ) + if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに diff --git a/train_db.py b/train_db.py index 81aeda19..164e354e 100644 --- a/train_db.py +++ b/train_db.py @@ -162,6 +162,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + if args.stop_text_encoder_training is None: args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end @@ -246,6 +249,7 @@ def train(args): text_encoder.requires_grad_(False) with accelerator.accumulate(unet): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): # latentに変換 if cache_latents: diff --git a/train_network.py b/train_network.py index 7f910df4..16f41ebb 100644 --- a/train_network.py +++ b/train_network.py @@ -200,6 +200,9 @@ def train(args): if is_main_process: print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -505,6 +508,7 @@ def train(args): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(network): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index e4ab7b5c..b3467d94 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -260,6 +260,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -338,6 +341,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device)