From 15136ca505fb392ec4deee6005571a3d1f95e87e Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 6 Nov 2025 19:02:59 -0500
Subject: [PATCH] Fix _grad_magnitude_ema_up _grad_magnitude_ema_down getting saved to LoRA

---
 networks/lora_flux.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index f3ef301e..51a67e3f 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -126,8 +126,8 @@ class LoRAModule(torch.nn.Module):
         self.mgpo_beta = mgpo_beta
 
         # EMA of gradient magnitudes for adaptive normalization
-        self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
-        self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
+        self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
+        self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)
 
         self.optimizer: torch.optim.Optimizer | None = None
 
@@ -337,24 +337,23 @@ class LoRAModule(torch.nn.Module):
     def update_gradient_ema(self):
         """
         Update EMA of gradient magnitudes for adaptive perturbation normalization
-
         Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
         """
         if self.mgpo_beta is None:
             return
-        
+
         # Update EMA for lora_down gradient magnitude
         if self.lora_down.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
-            self._grad_magnitude_ema_down.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_down.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )
-        
+
         # Update EMA for lora_up gradient magnitude
         if self.lora_up.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
-            self._grad_magnitude_ema_up.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_up.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )
 
     def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None: