From 5114e8daf14123f07710fd8eed6dd6baead9a572 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Jun 2023 08:46:53 +0900 Subject: [PATCH] fix training scripts except controlnet not working --- fine_tune.py | 10 +++++++--- library/config_util.py | 3 ++- requirements.txt | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 7 files changed, 14 insertions(+), 9 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 82e52dda..0afacdaf 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -42,7 +42,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True)) if args.dataset_config is not None: print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) @@ -260,7 +260,9 @@ def train(args): accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}") - accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}") + accelerator.print( + f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}" + ) accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}") accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") @@ -395,7 +397,9 @@ def train(args): current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず if args.logging_dir is not None: logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])} - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy" + ): # tracking d*lr value logs["lr/d*lr"] = ( lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"] ) diff --git a/library/config_util.py b/library/config_util.py index ae17655c..36c165a5 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -584,6 +584,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--support_dreambooth", action="store_true") parser.add_argument("--support_finetuning", action="store_true") + parser.add_argument("--support_controlnet", action="store_true") parser.add_argument("--support_dropout", action="store_true") parser.add_argument("dataset_config") config_args, remain = parser.parse_known_args() @@ -602,7 +603,7 @@ if __name__ == "__main__": print("\n[user_config]") print(user_config) - sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout) + sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout) sanitized_user_config = sanitizer.sanitize_user_config(user_config) print("\n[sanitized_user_config]") diff --git a/requirements.txt b/requirements.txt index a6bc8de5..74e06d21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.13.3 +huggingface-hub==0.14.1 # for BLIP captioning # requests==2.28.2 # timm==0.6.12 diff --git a/train_db.py b/train_db.py index ab094fdd..c8ddab1e 100644 --- a/train_db.py +++ b/train_db.py @@ -45,7 +45,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True)) if args.dataset_config is not None: print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) diff --git a/train_network.py b/train_network.py index 12e52248..7e930e8a 100644 --- a/train_network.py +++ b/train_network.py @@ -92,7 +92,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True)) if use_user_config: print(f"Loading dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 97f7435a..bcf0f196 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -151,7 +151,7 @@ def train(args): # データセットを準備する if args.dataset_class is None: - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) if args.dataset_config is not None: accelerator.print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 426799d1..3a87ede9 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -189,7 +189,7 @@ def train(args): print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}") # データセットを準備する - blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False)) + blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False)) if args.dataset_config is not None: print(f"Load dataset config from {args.dataset_config}") user_config = config_util.load_user_config(args.dataset_config)