mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Make grad_norm and combined_grad_norm None is not recording
This commit is contained in:
@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_grad_norms()
|
||||
|
||||
def grad_norms(self) -> Tensor:
|
||||
def grad_norms(self) -> Tensor | None:
|
||||
grad_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
|
||||
|
||||
def weight_norms(self) -> Tensor:
|
||||
def weight_norms(self) -> Tensor | None:
|
||||
weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
|
||||
|
||||
def combined_weight_norms(self) -> Tensor:
|
||||
def combined_weight_norms(self) -> Tensor | None:
|
||||
combined_weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
|
||||
@@ -1444,8 +1444,10 @@ class NetworkTrainer:
|
||||
else:
|
||||
if hasattr(network, "weight_norms"):
|
||||
mean_norm = network.weight_norms().mean().item()
|
||||
mean_grad_norm = network.grad_norms().mean().item()
|
||||
mean_combined_norm = network.combined_weight_norms().mean().item()
|
||||
grad_norms = network.grad_norms()
|
||||
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
|
||||
combined_weight_norms = network.combined_weight_norms()
|
||||
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
|
||||
weight_norms = network.weight_norms()
|
||||
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||
keys_scaled = None
|
||||
|
||||
Reference in New Issue
Block a user