Fix _grad_magnitude_ema_up and _grad_magnitude_ema_down getting saved to LoRA

rockerBOO
2025-11-06 19:02:59 -05:00
parent 3f47806719
commit 15136ca505


@@ -126,8 +126,8 @@ class LoRAModule(torch.nn.Module):
         self.mgpo_beta = mgpo_beta
         # EMA of gradient magnitudes for adaptive normalization
-        self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
-        self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
+        self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
+        self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)
         self.optimizer: torch.optim.Optimizer | None = None
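
A minimal sketch (standalone toy modules, not the real LoRAModule) of why this hunk fixes the bug: a Parameter still appears in state_dict() even with requires_grad=False, so the EMA scalars were being written into the saved LoRA file, while a buffer registered with persistent=False stays on the module but is excluded from serialization:

import torch

class Before(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # old approach: non-trainable Parameter, still exported with the weights
        self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)

class After(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # new approach: non-persistent buffer, kept on the module but not serialized
        self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)

print(list(Before().state_dict().keys()))  # ['_grad_magnitude_ema_down']
print(list(After().state_dict().keys()))   # []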
@@ -337,24 +337,23 @@ class LoRAModule(torch.nn.Module):
     def update_gradient_ema(self):
         """
         Update EMA of gradient magnitudes for adaptive perturbation normalization
         Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
         """
         if self.mgpo_beta is None:
             return
         # Update EMA for lora_down gradient magnitude
         if self.lora_down.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
-            self._grad_magnitude_ema_down.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_down.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )
         # Update EMA for lora_up gradient magnitude
         if self.lora_up.weight.grad is not None:
             current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
-            self._grad_magnitude_ema_up.data = (
-                self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
+            self._grad_magnitude_ema_up.mul_(self.mgpo_beta).add_(
+                current_grad_norm, alpha=(1 - self.mgpo_beta)
             )

     def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None:
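
A minimal sketch (plain tensors, not the real module) checking that the new in-place buffer update matches the old .data assignment, i.e. the EMA formula ḡ⁽ᵗ⁾ = β·ḡ⁽ᵗ⁻¹⁾ + (1 − β)·‖∇L‖₂ from the docstring:

import torch

beta = 0.9
grad = torch.randn(4, 8)
current_grad_norm = torch.norm(grad, p=2)

# old form: rebuild the value and assign through .data
ema_old = torch.tensor(1.0)
ema_old = beta * ema_old + (1 - beta) * current_grad_norm

# new form: update the (now non-persistent) buffer in place
ema_new = torch.tensor(1.0)
ema_new.mul_(beta).add_(current_grad_norm, alpha=(1 - beta))

assert torch.allclose(ema_old, ema_new)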