mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Make grad_norm and combined_grad_norm None if not recording
This commit is contained in:
@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
lora.update_grad_norms()
|
lora.update_grad_norms()
|
||||||
|
|
||||||
def grad_norms(self) -> Tensor:
|
def grad_norms(self) -> Tensor | None:
|
||||||
grad_norms = []
|
grad_norms = []
|
||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
|
||||||
|
|
||||||
def weight_norms(self) -> Tensor:
|
def weight_norms(self) -> Tensor | None:
|
||||||
weight_norms = []
|
weight_norms = []
|
||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
|
||||||
|
|
||||||
def combined_weight_norms(self) -> Tensor:
|
def combined_weight_norms(self) -> Tensor | None:
|
||||||
combined_weight_norms = []
|
combined_weight_norms = []
|
||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
|
||||||
|
|
||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
|
|||||||
@@ -1444,8 +1444,10 @@ class NetworkTrainer:
|
|||||||
else:
|
else:
|
||||||
if hasattr(network, "weight_norms"):
|
if hasattr(network, "weight_norms"):
|
||||||
mean_norm = network.weight_norms().mean().item()
|
mean_norm = network.weight_norms().mean().item()
|
||||||
mean_grad_norm = network.grad_norms().mean().item()
|
grad_norms = network.grad_norms()
|
||||||
mean_combined_norm = network.combined_weight_norms().mean().item()
|
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
|
||||||
|
combined_weight_norms = network.combined_weight_norms()
|
||||||
|
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
|
||||||
weight_norms = network.weight_norms()
|
weight_norms = network.weight_norms()
|
||||||
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||||
keys_scaled = None
|
keys_scaled = None
|
||||||
|
|||||||
Reference in New Issue
Block a user