This commit is contained in:
rockerBOO
2025-05-01 02:03:22 -04:00
parent f62c68df3c
commit b4a89c3cdf

View File

@@ -1443,13 +1443,13 @@ class NetworkTrainer:
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
weight_norms = network.weight_norms()
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
grad_norms = network.grad_norms()
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
combined_weight_norms = network.combined_weight_norms()
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
keys_scaled = None
max_mean_logs = {}
else: