diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 92b3979a..0b30f1b8 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
         for lora in self.text_encoder_loras + self.unet_loras:
             lora.update_grad_norms()
 
-    def grad_norms(self) -> Tensor:
+    def grad_norms(self) -> Tensor | None:
         grad_norms = []
         for lora in self.text_encoder_loras + self.unet_loras:
             if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
                 grad_norms.append(lora.grad_norms.mean(dim=0))
-        return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
+        return torch.stack(grad_norms) if len(grad_norms) > 0 else None
 
-    def weight_norms(self) -> Tensor:
+    def weight_norms(self) -> Tensor | None:
         weight_norms = []
         for lora in self.text_encoder_loras + self.unet_loras:
             if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
                 weight_norms.append(lora.weight_norms.mean(dim=0))
-        return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
+        return torch.stack(weight_norms) if len(weight_norms) > 0 else None
 
-    def combined_weight_norms(self) -> Tensor:
+    def combined_weight_norms(self) -> Tensor | None:
         combined_weight_norms = []
         for lora in self.text_encoder_loras + self.unet_loras:
             if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
                 combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
-        return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
+        return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
 
 
     def load_weights(self, file):
diff --git a/train_network.py b/train_network.py
index d6bc66ed..2b4e6d3f 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1444,8 +1444,10 @@ class NetworkTrainer:
                else:
                    if hasattr(network, "weight_norms"):
                        mean_norm = network.weight_norms().mean().item()
-                       mean_grad_norm = network.grad_norms().mean().item()
-                       mean_combined_norm = network.combined_weight_norms().mean().item()
+                       grad_norms = network.grad_norms()
+                       mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
+                       combined_weight_norms = network.combined_weight_norms()
+                       mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
                        weight_norms = network.weight_norms()
                        maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
                        keys_scaled = None