Make grad_norm and combined_grad_norm None if not recording

This commit is contained in:
rockerBOO
2025-05-01 01:37:57 -04:00
parent 80320d21fe
commit f62c68df3c
2 changed files with 10 additions and 8 deletions

View File

@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_grad_norms()
def grad_norms(self) -> Tensor:
def grad_norms(self) -> Tensor | None:
grad_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
grad_norms.append(lora.grad_norms.mean(dim=0))
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
def weight_norms(self) -> Tensor:
def weight_norms(self) -> Tensor | None:
weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
weight_norms.append(lora.weight_norms.mean(dim=0))
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
def combined_weight_norms(self) -> Tensor:
def combined_weight_norms(self) -> Tensor | None:
combined_weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
def load_weights(self, file):

View File

@@ -1444,8 +1444,10 @@ class NetworkTrainer:
else:
if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
mean_grad_norm = network.grad_norms().mean().item()
mean_combined_norm = network.combined_weight_norms().mean().item()
grad_norms = network.grad_norms()
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
combined_weight_norms = network.combined_weight_norms()
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
keys_scaled = None