Merge branch 'sd3' into network-wavelet-loss

This commit is contained in:
rockerBOO
2025-05-19 19:10:55 -04:00
6 changed files with 78 additions and 12 deletions

View File

@@ -1518,11 +1518,13 @@ class NetworkTrainer:
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
mean_grad_norm = network.grad_norms().mean().item()
mean_combined_norm = network.combined_weight_norms().mean().item()
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
grad_norms = network.grad_norms()
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
combined_weight_norms = network.combined_weight_norms()
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
keys_scaled = None
max_mean_logs = {}
else: