improve compatibility

This commit is contained in:
Kohya S
2023-04-04 07:48:48 +09:00
parent 83c7e03d05
commit e4eb3e63e6

View File

@@ -52,9 +52,9 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
idx = 1 idx = 1
for i in range(idx, len(lrs)): for i in range(idx, len(lrs)):
logs[f"lr/block{i}"] = float(lrs[i]) logs[f"lr/group{i}"] = float(lrs[i])
if args.optimizer_type.lower() == "DAdaptation".lower(): if args.optimizer_type.lower() == "DAdaptation".lower():
logs[f"lr/d*lr/block{i}"] = ( logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
) )
@@ -193,6 +193,9 @@ def train(args):
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs) network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None: if network is None:
return return
if hasattr(network, "prepare_network"):
network.prepare_network(args)
train_unet = not args.network_train_text_encoder_only train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only train_text_encoder = not args.network_train_unet_only
@@ -490,8 +493,6 @@ def train(args):
# add extra args # add extra args
if args.network_args: if args.network_args:
metadata["ss_network_args"] = json.dumps(net_kwargs) metadata["ss_network_args"] = json.dumps(net_kwargs)
# for key, value in net_kwargs.items():
# metadata["ss_arg_" + key] = value
# model name and hash # model name and hash
if args.pretrained_model_name_or_path is not None: if args.pretrained_model_name_or_path is not None: