add scaling to max norm

This commit is contained in:
Kohya S
2023-06-01 19:46:17 +09:00
parent a5c38e5d5b
commit f4c9276336
2 changed files with 41 additions and 31 deletions

View File

@@ -670,7 +670,7 @@ def train(args):
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms)
keys_scaled, mean_norm, maximum_norm = max_norm(network.state_dict(), args.scale_weight_norms, accelerator.device)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None