From e4eb3e63e67038897840ebb8c2c4d781f8cfde60 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Apr 2023 07:48:48 +0900 Subject: [PATCH] improve compatibility --- train_network.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index c79b0922..07bf44bb 100644 --- a/train_network.py +++ b/train_network.py @@ -52,9 +52,9 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche idx = 1 for i in range(idx, len(lrs)): - logs[f"lr/block{i}"] = float(lrs[i]) + logs[f"lr/group{i}"] = float(lrs[i]) if args.optimizer_type.lower() == "DAdaptation".lower(): - logs[f"lr/d*lr/block{i}"] = ( + logs[f"lr/d*lr/group{i}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) @@ -193,6 +193,9 @@ def train(args): network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) if network is None: return + + if hasattr(network, "prepare_network"): + network.prepare_network(args) train_unet = not args.network_train_text_encoder_only train_text_encoder = not args.network_train_unet_only @@ -490,8 +493,6 @@ def train(args): # add extra args if args.network_args: metadata["ss_network_args"] = json.dumps(net_kwargs) - # for key, value in net_kwargs.items(): - # metadata["ss_arg_" + key] = value # model name and hash if args.pretrained_model_name_or_path is not None: