From df8e1ac2f1fcf7f73de14c73cedf98e49791cd32 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Tue, 1 Apr 2025 22:55:52 -0400
Subject: [PATCH] Accumulate gradient sums

---
Review notes (this section is ignored when the patch is applied):
- gradient_noise_scale: total the per-module sums before dividing by the
  global element count. Taking torch.mean over the stacked per-module
  sums divided a second time by the number of modules.
- sum_grads: guard the squared-sum append on sum_squared_grads instead of
  repeating the sum_grads check.
- Dropped the commented-out previous implementation.
- Line breaks and indentation were restored from a flattened copy of the
  patch; the post-image blob id on the index line is carried over as-is.

 networks/lora_flux.py | 54 +++++++++++++++++++++++++++++++++----------
 1 file changed, 42 insertions(+), 12 deletions(-)

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 997857fc..9fb7f5a2 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -106,6 +106,9 @@ class LoRAModule(torch.nn.Module):
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.grad_count = 0
+        self.sum_grads = None
+        self.sum_squared_grads = None
 
         self.ggpo_sigma = ggpo_sigma
         self.ggpo_beta = ggpo_beta
@@ -296,7 +299,16 @@ class LoRAModule(torch.nn.Module):
     def accumulate_grad(self):
         for param in self.parameters():
             if param.grad is not None:
-                self.all_grad.append(param.grad.view(-1))
+                grad = param.grad.detach().flatten()
+                self.grad_count += grad.numel()
+
+                # Update running sums
+                if self.sum_grads is None:
+                    self.sum_grads = grad.sum()
+                    self.sum_squared_grads = (grad**2).sum()
+                else:
+                    self.sum_grads += grad.sum()
+                    self.sum_squared_grads += (grad**2).sum()
 
     @property
     def device(self):
@@ -984,26 +996,44 @@ class LoRANetwork(torch.nn.Module):
         for lora in self.text_encoder_loras + self.unet_loras:
             lora.accumulate_grad()
 
-    def all_grad(self):
-        all_grad = []
+    def sum_grads(self):
+        sum_grads = []
+        sum_squared_grads = []
+        count = 0
         for lora in self.text_encoder_loras + self.unet_loras:
-            all_grad.append(lora.all_grad)
+            if lora.sum_grads is not None:
+                sum_grads.append(lora.sum_grads)
+            if lora.sum_squared_grads is not None:
+                sum_squared_grads.append(lora.sum_squared_grads)
+            count += lora.grad_count
 
-        return torch.stack(all_grad)
+        return (
+            torch.stack(sum_grads) if len(sum_grads) > 0 else torch.tensor([]),
+            torch.stack(sum_squared_grads) if len(sum_squared_grads) > 0 else torch.tensor([]),
+            count,
+        )
 
     def gradient_noise_scale(self):
-        mean_grad = torch.mean(self.all_grads(), dim=0)
+        sum_grads, sum_squared_grads, count = self.sum_grads()
 
-        # Calculate trace of covariance matrix
-        centered_grads = all_grads - mean_grad
-        trace_cov = torch.mean(torch.sum(centered_grads**2, dim=1))
+        if count == 0:
+            return None
 
-        # Calculate norm of mean gradient squared
-        grad_norm_squared = torch.sum(mean_grad**2)
+        # Total the per-module sums, then divide once by the element count
+        # gathered across all modules to get E[X] and E[X²]
+        mean_grad = torch.sum(sum_grads, dim=0) / count
+        mean_squared_grad = torch.sum(sum_squared_grads, dim=0) / count
+
+        # Variance = E[X²] - E[X]²
+        variance = mean_squared_grad - mean_grad**2
+
+        # GNS = trace(Σ) / ||μ||²
+        # trace(Σ) = sum of variances = count * variance (for uniform variance assumption)
+        trace_cov = count * variance
+        grad_norm_squared = count * mean_grad**2
 
-        # Calculate GNS using provided gradient norm squared
         gradient_noise_scale = trace_cov / grad_norm_squared
 
         return gradient_noise_scale.item()
 
     def load_weights(self, file):