mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 08:21:46 +00:00
Fix _grad_magnitude_ema_up and _grad_magnitude_ema_down getting saved to LoRA
This commit is contained in:
@@ -126,8 +126,8 @@ class LoRAModule(torch.nn.Module):
|
||||
self.mgpo_beta = mgpo_beta
|
||||
|
||||
# EMA of gradient magnitudes for adaptive normalization
|
||||
self._grad_magnitude_ema_down = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
|
||||
self._grad_magnitude_ema_up = torch.nn.Parameter(torch.tensor(1.0), requires_grad=False)
|
||||
self.register_buffer('_grad_magnitude_ema_down', torch.tensor(1.0), persistent=False)
|
||||
self.register_buffer('_grad_magnitude_ema_up', torch.tensor(1.0), persistent=False)
|
||||
|
||||
self.optimizer: torch.optim.Optimizer | None = None
|
||||
|
||||
@@ -337,24 +337,23 @@ class LoRAModule(torch.nn.Module):
|
||||
def update_gradient_ema(self):
|
||||
"""
|
||||
Update EMA of gradient magnitudes for adaptive perturbation normalization
|
||||
|
||||
Formula: ḡₗ⁽ᵗ⁾ = β * ḡₗ⁽ᵗ⁻¹⁾ + (1 - β) * ||∇ΔWₗL||₂
|
||||
"""
|
||||
if self.mgpo_beta is None:
|
||||
return
|
||||
|
||||
|
||||
# Update EMA for lora_down gradient magnitude
|
||||
if self.lora_down.weight.grad is not None:
|
||||
current_grad_norm = torch.norm(self.lora_down.weight.grad, p=2)
|
||||
self._grad_magnitude_ema_down.data = (
|
||||
self.mgpo_beta * self._grad_magnitude_ema_down.data + (1 - self.mgpo_beta) * current_grad_norm
|
||||
self._grad_magnitude_ema_down.mul_(self.mgpo_beta).add_(
|
||||
current_grad_norm, alpha=(1 - self.mgpo_beta)
|
||||
)
|
||||
|
||||
|
||||
# Update EMA for lora_up gradient magnitude
|
||||
if self.lora_up.weight.grad is not None:
|
||||
current_grad_norm = torch.norm(self.lora_up.weight.grad, p=2)
|
||||
self._grad_magnitude_ema_up.data = (
|
||||
self.mgpo_beta * self._grad_magnitude_ema_up.data + (1 - self.mgpo_beta) * current_grad_norm
|
||||
self._grad_magnitude_ema_up.mul_(self.mgpo_beta).add_(
|
||||
current_grad_norm, alpha=(1 - self.mgpo_beta)
|
||||
)
|
||||
|
||||
def get_mgpo_output_perturbation(self, x: Tensor) -> Tensor | None:
|
||||
|
||||
Reference in New Issue
Block a user