mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix training scripts except controlnet not working
This commit is contained in:
10
fine_tune.py
10
fine_tune.py
@@ -42,7 +42,7 @@ def train(args):
|
|||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, True, False, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
print(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
@@ -260,7 +260,9 @@ def train(args):
|
|||||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
accelerator.print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
accelerator.print(
|
||||||
|
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
||||||
|
)
|
||||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
@@ -395,7 +397,9 @@ def train(args):
|
|||||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||||
if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy": # tracking d*lr value
|
if (
|
||||||
|
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy"
|
||||||
|
): # tracking d*lr value
|
||||||
logs["lr/d*lr"] = (
|
logs["lr/d*lr"] = (
|
||||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -584,6 +584,7 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--support_dreambooth", action="store_true")
|
parser.add_argument("--support_dreambooth", action="store_true")
|
||||||
parser.add_argument("--support_finetuning", action="store_true")
|
parser.add_argument("--support_finetuning", action="store_true")
|
||||||
|
parser.add_argument("--support_controlnet", action="store_true")
|
||||||
parser.add_argument("--support_dropout", action="store_true")
|
parser.add_argument("--support_dropout", action="store_true")
|
||||||
parser.add_argument("dataset_config")
|
parser.add_argument("dataset_config")
|
||||||
config_args, remain = parser.parse_known_args()
|
config_args, remain = parser.parse_known_args()
|
||||||
@@ -602,7 +603,7 @@ if __name__ == "__main__":
|
|||||||
print("\n[user_config]")
|
print("\n[user_config]")
|
||||||
print(user_config)
|
print(user_config)
|
||||||
|
|
||||||
sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_dropout)
|
sanitizer = ConfigSanitizer(config_args.support_dreambooth, config_args.support_finetuning, config_args.support_controlnet, config_args.support_dropout)
|
||||||
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
sanitized_user_config = sanitizer.sanitize_user_config(user_config)
|
||||||
|
|
||||||
print("\n[sanitized_user_config]")
|
print("\n[sanitized_user_config]")
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ altair==4.2.2
|
|||||||
easygui==0.98.3
|
easygui==0.98.3
|
||||||
toml==0.10.2
|
toml==0.10.2
|
||||||
voluptuous==0.13.1
|
voluptuous==0.13.1
|
||||||
huggingface-hub==0.13.3
|
huggingface-hub==0.14.1
|
||||||
# for BLIP captioning
|
# for BLIP captioning
|
||||||
# requests==2.28.2
|
# requests==2.28.2
|
||||||
# timm==0.6.12
|
# timm==0.6.12
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def train(args):
|
|||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, False, False, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
print(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ def train(args):
|
|||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
print(f"Loading dataset config from {args.dataset_config}")
|
print(f"Loading dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ def train(args):
|
|||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
accelerator.print(f"Load dataset config from {args.dataset_config}")
|
accelerator.print(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ def train(args):
|
|||||||
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, False))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
print(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user