From a9b26b73e0134884e945e0f5d15c6e804e046759 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Thu, 23 Mar 2023 07:37:14 +0900 Subject: [PATCH 01/10] implement token warmup --- fine_tune.py | 4 +++ library/config_util.py | 6 +++++ library/train_util.py | 55 ++++++++++++++++++++++++++++++++++++-- train_db.py | 4 +++ train_network.py | 4 +++ train_textual_inversion.py | 4 +++ 6 files changed, 75 insertions(+), 2 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index d927bd73..473a13ec 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -197,6 +197,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -263,6 +266,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/library/config_util.py b/library/config_util.py index e62bfb89..98d89b7e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -56,6 +56,8 @@ class BaseSubsetParams: caption_dropout_rate: float = 0.0 caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 + token_warmup_min: int = 1 + token_warmup_step: Union[float,int] = 0 @dataclass class DreamBoothSubsetParams(BaseSubsetParams): @@ -137,6 +139,8 @@ class ConfigSanitizer: "random_crop": bool, "shuffle_caption": bool, "keep_tokens": int, + "token_warmup_min": int, + "token_warmup_step": Union[float,int], } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { @@ -406,6 +410,8 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu flip_aug: {subset.flip_aug} face_crop_aug_range: {subset.face_crop_aug_range} random_crop: {subset.random_crop} + token_warmup_min: {subset.token_warmup_min}, + token_warmup_step: {subset.token_warmup_step}, """), " ") if is_dreambooth: diff --git a/library/train_util.py b/library/train_util.py index 7d311827..52b51314 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -277,6 +277,8 @@ class BaseSubset: caption_dropout_rate: float, caption_dropout_every_n_epochs: int, caption_tag_dropout_rate: float, + token_warmup_min: int, + token_warmup_step: Union[float,int], ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -290,6 +292,9 @@ class BaseSubset: self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs self.caption_tag_dropout_rate = caption_tag_dropout_rate + self.token_warmup_min = token_warmup_min # step=0におけるタグの数 + self.token_warmup_step = token_warmup_step # N(N<1ならN*max_train_steps)ステップ目でタグの数が最大になる + self.img_count = 0 @@ -310,6 +315,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" @@ -325,6 +332,8 @@ class DreamBoothSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) 
self.is_reg = is_reg @@ -352,6 +361,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" @@ -367,6 +378,8 @@ class FineTuningSubset(BaseSubset): caption_dropout_rate, caption_dropout_every_n_epochs, caption_tag_dropout_rate, + token_warmup_min, + token_warmup_step, ) self.metadata_file = metadata_file @@ -405,6 +418,9 @@ class BaseDataset(torch.utils.data.Dataset): self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ + self.current_step: int = 0 + self.max_train_steps: int = 0 + # augmentation self.aug_helper = AugHelper() @@ -424,6 +440,12 @@ class BaseDataset(torch.utils.data.Dataset): self.current_epoch = epoch self.shuffle_buckets() + def set_current_step(self, step): + self.current_step = step + + def set_max_train_steps(self, max_train_steps): + self.max_train_steps = max_train_steps + def set_tag_frequency(self, dir_name, captions): frequency_for_dir = self.tag_frequency.get(dir_name, {}) self.tag_frequency[dir_name] = frequency_for_dir @@ -453,7 +475,7 @@ class BaseDataset(torch.utils.data.Dataset): if is_drop_out: caption = "" else: - if subset.shuffle_caption or subset.caption_tag_dropout_rate > 0: + if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: @@ -474,8 +496,15 @@ class BaseDataset(torch.utils.data.Dataset): random.shuffle(flex_tokens) flex_tokens = dropout_tags(flex_tokens) + tokens = fixed_tokens + flex_tokens - caption = ", ".join(fixed_tokens + flex_tokens) + if subset.token_warmup_step < 1: + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min + tokens = tokens[:tokens_len] + + caption = ", ".join(tokens) # textual inversion対応 for str_from, str_to in self.replacements.items(): @@ -1249,6 +1278,14 @@ class DatasetGroup(torch.utils.data.ConcatDataset): for dataset in self.datasets: dataset.set_current_epoch(epoch) + def set_current_step(self, step): + for dataset in self.datasets: + dataset.set_current_step(step) + + def set_max_train_steps(self, max_train_steps): + for dataset in self.datasets: + dataset.set_max_train_steps(max_train_steps) + def disable_token_padding(self): for dataset in self.datasets: dataset.disable_token_padding() @@ -2001,6 +2038,20 @@ def add_dataset_arguments( "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" ) + parser.add_argument( + "--token_warmup_min", + type=int, + default=1, + help="start learning at N tags (token means comma separated strinfloatgs) / タグ数をN個から増やしながら学習する", + ) + + parser.add_argument( + "--token_warmup_steps", + type=float, + default=0, + help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", + ) + if support_caption_dropout: # Textual Inversion はcaptionのdropoutをsupportしない # いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに diff --git a/train_db.py b/train_db.py index 81aeda19..164e354e 100644 --- a/train_db.py +++ b/train_db.py @@ -162,6 +162,9 @@ def 
train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + if args.stop_text_encoder_training is None: args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end @@ -246,6 +249,7 @@ def train(args): text_encoder.requires_grad_(False) with accelerator.accumulate(unet): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): # latentに変換 if cache_latents: diff --git a/train_network.py b/train_network.py index 7f910df4..16f41ebb 100644 --- a/train_network.py +++ b/train_network.py @@ -200,6 +200,9 @@ def train(args): if is_main_process: print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -505,6 +508,7 @@ def train(args): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(network): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index e4ab7b5c..b3467d94 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -260,6 +260,9 @@ def train(args): args.max_train_steps = args.max_train_epochs * len(train_dataloader) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -338,6 +341,7 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): + train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) From 447c56bf505c2a84d00e88ac173a1b6961894429 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Thu, 23 Mar 2023 09:53:14 +0900 Subject: [PATCH 02/10] =?UTF-8?q?typo=E4=BF=AE=E6=AD=A3=E3=80=81step?= =?UTF-8?q?=E3=82=92global=5Fstep=E3=81=AB=E4=BF=AE=E6=AD=A3=E3=80=81?= =?UTF-8?q?=E3=83=90=E3=82=B0=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 2 +- library/config_util.py | 4 ++-- library/train_util.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 473a13ec..def942fa 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -265,8 +265,8 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): + train_dataset_group.set_current_step(global_step) with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく - train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/library/config_util.py b/library/config_util.py index 
98d89b7e..84bbf308 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -57,7 +57,7 @@ class BaseSubsetParams: caption_dropout_every_n_epochs: int = 0 caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 - token_warmup_step: Union[float,int] = 0 + token_warmup_step: float = 0 @dataclass class DreamBoothSubsetParams(BaseSubsetParams): @@ -140,7 +140,7 @@ class ConfigSanitizer: "shuffle_caption": bool, "keep_tokens": int, "token_warmup_min": int, - "token_warmup_step": Union[float,int], + "token_warmup_step": Any(float,int), } # DO means DropOut DO_SUBSET_ASCENDABLE_SCHEMA = { diff --git a/library/train_util.py b/library/train_util.py index 52b51314..83e9372b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2046,7 +2046,7 @@ def add_dataset_arguments( ) parser.add_argument( - "--token_warmup_steps", + "--token_warmup_step", type=float, default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", diff --git a/train_db.py b/train_db.py index 164e354e..e17a8b79 100644 --- a/train_db.py +++ b/train_db.py @@ -241,6 +241,7 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): + train_dataset_group.set_current_step(global_step) # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") @@ -249,7 +250,6 @@ def train(args): text_encoder.requires_grad_(False) with accelerator.accumulate(unet): - train_dataset_group.set_current_step(step + 1) with torch.no_grad(): # latentに変換 if cache_latents: diff --git a/train_network.py b/train_network.py index 16f41ebb..6d23ab07 100644 --- a/train_network.py +++ b/train_network.py @@ -507,8 +507,8 @@ def train(args): network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): + train_dataset_group.set_current_step(global_step) with accelerator.accumulate(network): - train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index b3467d94..42746169 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -340,8 +340,8 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): + train_dataset_group.set_current_step(global_step) with accelerator.accumulate(text_encoder): - train_dataset_group.set_current_step(step + 1) with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device) From dbadc40ec2eb2de92b21fd3b5aa82994899705cc Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Thu, 23 Mar 2023 12:33:03 +0900 Subject: [PATCH 03/10] =?UTF-8?q?persistent=5Fworkers=E3=82=92=E6=9C=89?= =?UTF-8?q?=E5=8A=B9=E3=81=AB=E3=81=97=E3=81=9F=E9=9A=9B=E3=81=AB=E3=82=AD?= =?UTF-8?q?=E3=83=A3=E3=83=97=E3=82=B7=E3=83=A7=E3=83=B3=E3=81=8C=E5=A4=89?= =?UTF-8?q?=E5=8C=96=E3=81=97=E3=81=AA=E3=81=8F=E3=81=AA=E3=82=8B=E3=83=90?= =?UTF-8?q?=E3=82=B0=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 3 ++- library/config_util.py | 8 ++++++++ train_db.py | 3 ++- train_network.py | 3 ++- train_textual_inversion.py | 3 ++- 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 
def942fa..ff580435 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -62,6 +62,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.debug_dataset: @@ -259,13 +260,13 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) + train_dataset_group.set_current_step(global_step) for m in training_models: m.train() loss_total = 0 for step, batch in enumerate(train_dataloader): - train_dataset_group.set_current_step(global_step) with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: diff --git a/library/config_util.py b/library/config_util.py index 84bbf308..efeb8016 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -497,6 +497,14 @@ def load_user_config(file: str) -> dict: return config +def blueprint_args_conflict(args,blueprint:Blueprint): + # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする + for b in blueprint.dataset_group.datasets: + for t in b.subsets: + if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): + print("Warning: %s: caption_dropout_every_n_epochs and token_warmup_step is ignored because --persistent_data_loader_workers option is used / --persistent_data_loader_workersオプションが使われているため、caption_dropout_every_n_epochs及びtoken_warmup_stepは無視されます。"%(t.params.image_dir)) + t.params.caption_dropout_every_n_epochs = 0 + t.params.token_warmup_step = 0 # for config test if __name__ == "__main__": diff --git a/train_db.py b/train_db.py index e17a8b79..87fe771b 100644 --- a/train_db.py +++ b/train_db.py @@ -57,6 +57,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.no_token_padding: @@ -233,6 +234,7 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) + train_dataset_group.set_current_step(global_step) # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -241,7 +243,6 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): - train_dataset_group.set_current_step(global_step) # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") diff --git a/train_network.py b/train_network.py index 6d23ab07..02a2d925 100644 --- a/train_network.py +++ b/train_network.py @@ -98,6 +98,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) if args.debug_dataset: @@ -501,13 +502,13 @@ def train(args): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) + train_dataset_group.set_current_step(global_step) metadata["ss_epoch"] = str(epoch + 1) 
network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): - train_dataset_group.set_current_step(global_step) with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 42746169..63b63426 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -183,6 +183,7 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) + config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 @@ -335,12 +336,12 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") train_dataset_group.set_current_epoch(epoch + 1) + train_dataset_group.set_current_step(global_step) text_encoder.train() loss_total = 0 for step, batch in enumerate(train_dataloader): - train_dataset_group.set_current_step(global_step) with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From 143c26e55219ba9fe51fd1b50f8922d2f2de9c8a Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 24 Mar 2023 13:08:56 +0900 Subject: [PATCH 04/10] =?UTF-8?q?=E7=AB=B6=E5=90=88=E6=99=82=E3=81=ABpersi?= =?UTF-8?q?stant=5Fdata=5Floader=E5=81=B4=E3=82=92=E7=84=A1=E5=8A=B9?= =?UTF-8?q?=E3=81=AB=E3=81=99=E3=82=8B=E3=82=88=E3=81=86=E3=81=AB=E5=A4=89?= =?UTF-8?q?=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/config_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index efeb8016..9c8c90c2 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -502,9 +502,8 @@ def blueprint_args_conflict(args,blueprint:Blueprint): for b in blueprint.dataset_group.datasets: for t in b.subsets: if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): - print("Warning: %s: caption_dropout_every_n_epochs and token_warmup_step is ignored because --persistent_data_loader_workers option is used / --persistent_data_loader_workersオプションが使われているため、caption_dropout_every_n_epochs及びtoken_warmup_stepは無視されます。"%(t.params.image_dir)) - t.params.caption_dropout_every_n_epochs = 0 - t.params.token_warmup_step = 0 + print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. 
/ caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) + args.persistent_data_loader_workers = False # for config test if __name__ == "__main__": From 1b89b2a10e1f623efd3945d422dcd0640ac4f0fd Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Fri, 24 Mar 2023 13:44:30 +0900 Subject: [PATCH 05/10] =?UTF-8?q?=E3=82=B7=E3=83=A3=E3=83=83=E3=83=95?= =?UTF-8?q?=E3=83=AB=E5=89=8D=E3=81=AB=E3=82=BF=E3=82=B0=E3=82=92=E5=88=87?= =?UTF-8?q?=E3=82=8A=E8=A9=B0=E3=82=81=E3=82=8B=E3=82=88=E3=81=86=E3=81=AB?= =?UTF-8?q?=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/train_util.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 83e9372b..d1df9c58 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -477,6 +477,13 @@ class BaseDataset(torch.utils.data.Dataset): else: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: + tokens = [t.strip() for t in caption.strip().split(",")] + if subset.token_warmup_step < 1: + subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) + if subset.token_warmup_step and self.current_step < subset.token_warmup_step: + tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min + tokens = tokens[:tokens_len] + def dropout_tags(tokens): if subset.caption_tag_dropout_rate <= 0: return tokens @@ -487,24 +494,17 @@ class BaseDataset(torch.utils.data.Dataset): return l fixed_tokens = [] - flex_tokens = [t.strip() for t in caption.strip().split(",")] + flex_tokens = tokens[:] if subset.keep_tokens > 0: fixed_tokens = flex_tokens[: subset.keep_tokens] - flex_tokens = flex_tokens[subset.keep_tokens :] + flex_tokens = tokens[subset.keep_tokens :] if subset.shuffle_caption: random.shuffle(flex_tokens) flex_tokens = dropout_tags(flex_tokens) - tokens = fixed_tokens + flex_tokens - if subset.token_warmup_step < 1: - subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) - if subset.token_warmup_step and self.current_step < subset.token_warmup_step: - tokens_len = math.floor((self.current_step)*((len(tokens)-subset.token_warmup_min)/(subset.token_warmup_step)))+subset.token_warmup_min - tokens = tokens[:tokens_len] - - caption = ", ".join(tokens) + caption = ", ".join(fixed_tokens + flex_tokens) # textual inversion対応 for str_from, str_to in self.replacements.items(): From 5ec90990de870a4579721db947c2f74b9ce3ed69 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 01:41:24 +0900 Subject: [PATCH 06/10] =?UTF-8?q?=E3=83=87=E3=83=BC=E3=82=BF=E3=82=BB?= =?UTF-8?q?=E3=83=83=E3=83=88=E3=81=ABepoch=E3=80=81step=E3=81=8C=E9=80=9A?= =?UTF-8?q?=E9=81=94=E3=81=95=E3=82=8C=E3=81=AA=E3=81=84=E3=83=90=E3=82=B0?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_network.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/train_network.py b/train_network.py index 02a2d925..ef10921f 100644 --- a/train_network.py +++ b/train_network.py @@ -8,6 +8,7 @@ import random import time import json import toml +from multiprocessing import Value from tqdm import tqdm import 
torch @@ -24,9 +25,18 @@ from library.config_util import ( BlueprintGenerator, ) - -def collate_fn(examples): - return examples[0] +class collater_class: + def __init__(self,epoch,step): + self.current_epoch=epoch + self.current_step=step + def __call__(self, examples): + dataset = torch.utils.data.get_worker_info().dataset + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + # print("self.current_step:%d"%self.current_step) + # print("dataset_lengh:%d"%len(dataset)) + print("id(self)(collate):%d"%id(self)) + return examples[0] # TODO 他のスクリプトと共通化する @@ -101,6 +111,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = collater_class(current_epoch,current_step) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -186,11 +200,12 @@ def train(args): # dataloaderを準備する # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + train_dataloader = torch.utils.data.DataLoader( train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -498,17 +513,18 @@ def train(args): loss_list = [] loss_total = 0.0 + del train_dataset_group for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 metadata["ss_epoch"] = str(epoch + 1) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From 292cdb8379f00f87e9f1391a0ff2508a3540bd13 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 01:44:25 +0900 Subject: [PATCH 07/10] =?UTF-8?q?=E3=83=87=E3=83=BC=E3=82=BF=E3=82=BB?= =?UTF-8?q?=E3=83=83=E3=83=88=E3=81=ABepoch=E3=80=81step=E3=81=8C=E9=80=9A?= =?UTF-8?q?=E9=81=94=E3=81=95=E3=82=8C=E3=81=AA=E3=81=84=E3=83=90=E3=82=B0?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/config_util.py | 11 ++++++----- library/train_util.py | 1 + train_network.py | 23 +++++++++++++++++------ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 9c8c90c2..6817d9a3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -499,11 +499,12 @@ def load_user_config(file: str) -> dict: def blueprint_args_conflict(args,blueprint:Blueprint): # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする - for b in blueprint.dataset_group.datasets: - for t in b.subsets: - if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): - print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. 
/ caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) - args.persistent_data_loader_workers = False + # for b in blueprint.dataset_group.datasets: + # for t in b.subsets: + # if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): + # print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. / caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) + # # args.persistent_data_loader_workers = False + return # for config test if __name__ == "__main__": diff --git a/library/train_util.py b/library/train_util.py index d1df9c58..223e403b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -517,6 +517,7 @@ class BaseDataset(torch.utils.data.Dataset): else: caption = caption.replace(str_from, str_to) + print(self.current_step, self.max_train_steps, caption) return caption def get_input_ids(self, caption): diff --git a/train_network.py b/train_network.py index 02a2d925..e148a92a 100644 --- a/train_network.py +++ b/train_network.py @@ -8,6 +8,7 @@ import random import time import json import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -24,9 +25,15 @@ from library.config_util import ( BlueprintGenerator, ) - -def collate_fn(examples): - return examples[0] +class collater_class: + def __init__(self,epoch,step): + self.current_epoch=epoch + self.current_step=step + def __call__(self, examples): + dataset = torch.utils.data.get_worker_info().dataset + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] # TODO 他のスクリプトと共通化する @@ -101,6 +108,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = collater_class(current_epoch,current_step) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -190,7 +201,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -501,14 +512,14 @@ def train(args): for epoch in range(num_train_epochs): if is_main_process: print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 metadata["ss_epoch"] = str(epoch + 1) network.on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(network): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From 4dc1124f9339af9886f4a73db49e9e7c9cf17a23 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 02:19:55 +0900 Subject: [PATCH 08/10] =?UTF-8?q?lora=E4=BB=A5=E5=A4=96=E3=82=82=E5=AF=BE?= =?UTF-8?q?=E5=BF=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 15 ++++++++------- library/train_util.py | 11 +++++++++++ train_db.py | 15 ++++++++------- train_network.py | 13 +------------ train_textual_inversion.py | 16 
+++++++++------- 5 files changed, 37 insertions(+), 33 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index ff580435..96ec3d96 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -6,6 +6,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -21,10 +22,6 @@ from library.config_util import ( ) -def collate_fn(examples): - return examples[0] - - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, True) @@ -65,6 +62,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -188,7 +189,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -259,14 +260,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 for m in training_models: m.train() loss_total = 0 for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: diff --git a/library/train_util.py b/library/train_util.py index 223e403b..994201fc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2987,3 +2987,14 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # endregion + +# colalte_fn用 epoch,stepはmultiprocessing.Value +class collater_class: + def __init__(self,epoch,step): + self.current_epoch=epoch + self.current_step=step + def __call__(self, examples): + dataset = torch.utils.data.get_worker_info().dataset + dataset.set_current_epoch(self.current_epoch.value) + dataset.set_current_step(self.current_step.value) + return examples[0] \ No newline at end of file diff --git a/train_db.py b/train_db.py index 87fe771b..50d50345 100644 --- a/train_db.py +++ b/train_db.py @@ -8,6 +8,7 @@ import itertools import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -23,10 +24,6 @@ from library.config_util import ( ) -def collate_fn(examples): - return examples[0] - - def train(args): train_util.verify_training_args(args) train_util.prepare_dataset_args(args, False) @@ -60,6 +57,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + if args.no_token_padding: train_dataset_group.disable_token_padding() @@ -153,7 +154,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -233,8 +234,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - 
train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -243,6 +243,7 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): + current_step.value = global_step # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: print(f"stop text encoder training at step {global_step}") diff --git a/train_network.py b/train_network.py index dd1fb748..79d118d0 100644 --- a/train_network.py +++ b/train_network.py @@ -25,17 +25,6 @@ from library.config_util import ( BlueprintGenerator, ) -class collater_class: - def __init__(self,epoch,step): - self.current_epoch=epoch - self.current_step=step - def __call__(self, examples): - dataset = torch.utils.data.get_worker_info().dataset - dataset.set_current_epoch(self.current_epoch.value) - dataset.set_current_step(self.current_step.value) - return examples[0] - - # TODO 他のスクリプトと共通化する def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -110,7 +99,7 @@ def train(args): current_epoch = Value('i',0) current_step = Value('i',0) - collater = collater_class(current_epoch,current_step) + collater = train_util.collater_class(current_epoch,current_step) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 63b63426..4f2e2724 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -4,6 +4,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -71,10 +72,6 @@ imagenet_style_templates_small = [ ] -def collate_fn(examples): - return examples[0] - - def train(args): if args.output_name is None: args.output_name = args.token_string @@ -186,6 +183,10 @@ def train(args): config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + current_epoch = Value('i',0) + current_step = Value('i',0) + collater = train_util.collater_class(current_epoch,current_step) + # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: print("use template for training captions. 
is object: {args.use_object_template}") @@ -251,7 +252,7 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) @@ -335,13 +336,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) - train_dataset_group.set_current_step(global_step) + current_epoch.value = epoch+1 text_encoder.train() loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: From 5a3d564a3028057b1d7671b4a570bd37f13aa8d6 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 02:26:08 +0900 Subject: [PATCH 09/10] =?UTF-8?q?print=E5=89=8A=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- library/train_util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 994201fc..1ba38adc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -517,7 +517,6 @@ class BaseDataset(torch.utils.data.Dataset): else: caption = caption.replace(str_from, str_to) - print(self.current_step, self.max_train_steps, caption) return caption def get_input_ids(self, caption): From a4b34a9c3ce36c93ae12d5b4b98ea9f93e1606a3 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 26 Mar 2023 03:26:55 +0900 Subject: [PATCH 10/10] =?UTF-8?q?blueprint=5Fargs=5Fconflict=E3=81=AF?= =?UTF-8?q?=E4=B8=8D=E8=A6=81=E3=81=AA=E3=81=9F=E3=82=81=E5=89=8A=E9=99=A4?= =?UTF-8?q?=E3=80=81shuffle=E3=81=8C=E6=AF=8E=E5=9B=9E=E8=A1=8C=E3=82=8F?= =?UTF-8?q?=E3=82=8C=E3=82=8B=E4=B8=8D=E5=85=B7=E5=90=88=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fine_tune.py | 1 - library/config_util.py | 9 --------- library/train_util.py | 3 ++- train_db.py | 1 - train_network.py | 1 - train_textual_inversion.py | 1 - 6 files changed, 2 insertions(+), 14 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 96ec3d96..f387179a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -59,7 +59,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/library/config_util.py b/library/config_util.py index 6817d9a3..b1543f63 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -497,15 +497,6 @@ def load_user_config(file: str) -> dict: return config -def blueprint_args_conflict(args,blueprint:Blueprint): - # train_dataset_group.set_current_epoch()とtrain_dataset_group.set_current_step()がWorkerを生成するタイミングで適用される影響で、persistent_workers有効時はずっと一定になってしまうため無効にする - # for b in blueprint.dataset_group.datasets: - # for t in b.subsets: - # if args.persistent_data_loader_workers and (t.params.caption_dropout_every_n_epochs > 0 or t.params.token_warmup_step>0): - # print("Warning: %s: --persistent_data_loader_workers option is disabled because it conflicts with caption_dropout_every_n_epochs and token_wormup_step. 
/ caption_dropout_every_n_epochs及びtoken_warmup_stepと競合するため、--persistent_data_loader_workersオプションは無効になります。"%(t.params.image_dir)) - # # args.persistent_data_loader_workers = False - return - # for config test if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/library/train_util.py b/library/train_util.py index 1ba38adc..af9637de 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -437,8 +437,9 @@ class BaseDataset(torch.utils.data.Dataset): self.replacements = {} def set_current_epoch(self, epoch): + if not self.current_epoch == epoch: + self.shuffle_buckets() self.current_epoch = epoch - self.shuffle_buckets() def set_current_step(self, step): self.current_step = step diff --git a/train_db.py b/train_db.py index 50d50345..3a3d2df8 100644 --- a/train_db.py +++ b/train_db.py @@ -54,7 +54,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/train_network.py b/train_network.py index 79d118d0..a7dbd374 100644 --- a/train_network.py +++ b/train_network.py @@ -94,7 +94,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 4f2e2724..149308b4 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -180,7 +180,6 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) - config_util.blueprint_args_conflict(args,blueprint) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) current_epoch = Value('i',0)
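For reference, a minimal standalone sketch (not part of the patches above) of the token-warmup truncation these commits add to the caption-processing path in library/train_util.py. The helper name and the numbers in the example are illustrative only; the logic mirrors the final form from PATCH 05, where a token_warmup_step below 1 is treated as a fraction of max_train_steps and the number of visible tags ramps linearly from token_warmup_min up to the full tag list.

import math

def warmup_token_count(tokens, current_step, max_train_steps,
                       token_warmup_min=1, token_warmup_step=0):
    # token_warmup_step < 1 means "fraction of max_train_steps"; 0 disables warmup
    if token_warmup_step < 1:
        token_warmup_step = math.floor(token_warmup_step * max_train_steps)
    if token_warmup_step and current_step < token_warmup_step:
        # linear ramp: token_warmup_min tags at step 0, all tags once token_warmup_step is reached
        tokens_len = (
            math.floor(current_step * ((len(tokens) - token_warmup_min) / token_warmup_step))
            + token_warmup_min
        )
        return tokens[:tokens_len]
    return tokens

# Example: a caption with 10 tags, token_warmup_min=1, token_warmup_step=100
# -> 1 tag at step 0, 5 tags at step 50, all 10 tags from step 100 onward.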