diff --git a/train_network.py b/train_network.py index 2b4e6d3f..1336a0b1 100644 --- a/train_network.py +++ b/train_network.py @@ -1443,13 +1443,13 @@ class NetworkTrainer: max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm} else: if hasattr(network, "weight_norms"): - mean_norm = network.weight_norms().mean().item() + weight_norms = network.weight_norms() + mean_norm = weight_norms.mean().item() if weight_norms is not None else None grad_norms = network.grad_norms() mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None combined_weight_norms = network.combined_weight_norms() mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None - weight_norms = network.weight_norms() - maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None + maximum_norm = weight_norms.max().item() if weight_norms is not None else None keys_scaled = None max_mean_logs = {} else: