move max_norm to lora to avoid crashing in lycoris

This commit is contained in:
Kohya S
2023-06-03 12:42:32 +09:00
parent 6084611508
commit 5bec05e045
4 changed files with 55 additions and 45 deletions

View File

@@ -31,7 +31,6 @@ from library.custom_train_functions import (
prepare_scheduler_for_custom_training,
pyramid_noise_like,
apply_noise_offset,
max_norm,
scale_v_prediction_loss_like_noise_prediction,
)
@@ -220,6 +219,11 @@ def train(args):
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
print(
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
)
args.scale_weight_norms = False
train_unet = not args.network_train_text_encoder_only
train_text_encoder = not args.network_train_unet_only
@@ -677,7 +681,9 @@ def train(args):
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms, accelerator.device)
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None