mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
improve compatibility
This commit is contained in:
@@ -52,9 +52,9 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
|
||||
idx = 1
|
||||
|
||||
for i in range(idx, len(lrs)):
|
||||
logs[f"lr/block{i}"] = float(lrs[i])
|
||||
logs[f"lr/group{i}"] = float(lrs[i])
|
||||
if args.optimizer_type.lower() == "DAdaptation".lower():
|
||||
logs[f"lr/d*lr/block{i}"] = (
|
||||
logs[f"lr/d*lr/group{i}"] = (
|
||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||
)
|
||||
|
||||
@@ -194,6 +194,9 @@ def train(args):
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
|
||||
train_unet = not args.network_train_text_encoder_only
|
||||
train_text_encoder = not args.network_train_unet_only
|
||||
network.apply_to(text_encoder, unet, train_text_encoder, train_unet)
|
||||
@@ -490,8 +493,6 @@ def train(args):
|
||||
# add extra args
|
||||
if args.network_args:
|
||||
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
||||
# for key, value in net_kwargs.items():
|
||||
# metadata["ss_arg_" + key] = value
|
||||
|
||||
# model name and hash
|
||||
if args.pretrained_model_name_or_path is not None:
|
||||
|
||||
Reference in New Issue
Block a user