Fix training scripts (other than ControlNet) not working

Kohya S
2023-06-22 08:46:53 +09:00
parent 1c09867b3e
commit 5114e8daf1
7 changed files with 14 additions and 9 deletions


@@ -42,7 +42,7 @@ def train(args):
     # データセットを準備する
     if args.dataset_class is None:
-        blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
         if args.dataset_config is not None:
             print(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)
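
For context: every hunk in this commit passes a fourth boolean into ConfigSanitizer, so its constructor has presumably gained a support_controlnet flag between support_finetuning and support_dropout. The sketch below is an assumption made for illustration (parameter names inferred from the --support_* flags added later in this commit), not the actual library code:

    # Assumed shape of the updated constructor; illustrative only.
    class ConfigSanitizer:
        def __init__(
            self,
            support_dreambooth: bool,
            support_finetuning: bool,
            support_controlnet: bool,  # new flag this commit threads through every caller
            support_dropout: bool,
        ) -> None:
            self.support_dreambooth = support_dreambooth
            self.support_finetuning = support_finetuning
            self.support_controlnet = support_controlnet
            self.support_dropout = support_dropout

    # e.g. the fine-tuning caller above: ConfigSanitizer(False, True, False, True)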
@@ -260,7 +260,9 @@ def train(args):
     accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
     accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
     accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
-    accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    accelerator.print(
+        f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
+    )
     accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
     accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
@@ -395,7 +397,9 @@ def train(args):
             current_loss = loss.detach().item()  # 平均なのでbatch sizeは関係ないはず
             if args.logging_dir is not None:
                 logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
-                if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy":  # tracking d*lr value
+                if (
+                    args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy"
+                ):  # tracking d*lr value
                     logs["lr/d*lr"] = (
                         lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
                     )
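
The quantity logged in this hunk is the adaptive step-size estimate d multiplied by the configured learning rate, which D-Adaptation-style optimizers and Prodigy expose through their param groups. A minimal sketch of that read, assuming an optimizer whose param groups carry a "d" key (the helper name is made up for illustration, not part of the diff):

    # Illustrative helper: effective learning rate for D-Adaptation / Prodigy optimizers.
    def d_times_lr(optimizer) -> float:
        group = optimizer.param_groups[0]
        # "d" is the optimizer's adaptive distance estimate; multiplied by the base
        # "lr" it gives the effective learning rate the hunk above logs as "lr/d*lr".
        return group["d"] * group["lr"]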


@@ -584,6 +584,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--support_dreambooth", action="store_true")
     parser.add_argument("--support_finetuning", action="store_true")
+    parser.add_argument("--support_controlnet", action="store_true")
     parser.add_argument("--support_dropout", action="store_true")
     parser.add_argument("dataset_config")
     config_args, remain = parser.parse_known_args()
@@ -602,7 +603,7 @@ if __name__ == "__main__":
     print("\n[user_config]")
     print(user_config)
-    sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
+    sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout)
     sanitized_user_config = sanitizer.sanitize_user_config(user_config)
     print("\n[sanitized_user_config]")


@@ -14,7 +14,7 @@ altair==4.2.2
 easygui==0.98.3
 toml==0.10.2
 voluptuous==0.13.1
-huggingface-hub==0.13.3
+huggingface-hub==0.14.1
 # for BLIP captioning
 # requests==2.28.2
 # timm==0.6.12


@@ -45,7 +45,7 @@ def train(args):
     # データセットを準備する
     if args.dataset_class is None:
-        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
         if args.dataset_config is not None:
             print(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)


@@ -92,7 +92,7 @@ def train(args):
     # データセットを準備する
     if args.dataset_class is None:
-        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
         if use_user_config:
             print(f"Loading dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)


@@ -151,7 +151,7 @@ def train(args):
     # データセットを準備する
     if args.dataset_class is None:
-        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
+        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
         if args.dataset_config is not None:
             accelerator.print(f"Load dataset config from {args.dataset_config}")
             user_config = config_util.load_user_config(args.dataset_config)


@@ -189,7 +189,7 @@ def train(args):
     print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
     # データセットを準備する
-    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
+    blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
     if args.dataset_config is not None:
         print(f"Load dataset config from {args.dataset_config}")
         user_config = config_util.load_user_config(args.dataset_config)