mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
move max_norm to lora to avoid crashing in lycoris
This commit is contained in:
@@ -31,7 +31,6 @@ from library.custom_train_functions import (
|
||||
prepare_scheduler_for_custom_training,
|
||||
pyramid_noise_like,
|
||||
apply_noise_offset,
|
||||
max_norm,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
)
|
||||
|
||||
@@ -220,6 +219,11 @@ def train(args):
|
||||
|
||||
if hasattr(network, "prepare_network"):
|
||||
network.prepare_network(args)
|
||||
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
|
||||
print(
|
||||
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
|
||||
)
|
||||
args.scale_weight_norms = False
|
||||
|
||||
train_unet = not args.network_train_text_encoder_only
|
||||
train_text_encoder = not args.network_train_unet_only
|
||||
@@ -677,7 +681,9 @@ def train(args):
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms, accelerator.device)
|
||||
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
keys_scaled, mean_norm, maximum_norm = None, None, None
|
||||
|
||||
Reference in New Issue
Block a user