Accumulate gradient sums

Author: rockerBOO
Date: 2025-04-01 22:55:52 -04:00
parent 90bcab09d8 · commit df8e1ac2f1


@@ -106,6 +106,9 @@ class LoRAModule(torch.nn.Module):
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.grad_count = 0
+        self.sum_grads = None
+        self.sum_squared_grads = None
         self.ggpo_sigma = ggpo_sigma
         self.ggpo_beta = ggpo_beta
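These three attributes are the entire state needed for streaming statistics: a running element count, sum, and sum of squares are enough to recover the mean and variance of every gradient entry later, without retaining the gradients themselves, via Var[X] = E[X²] − E[X]². A minimal standalone check of that identity (not part of the commit):

```python
import torch

# Hypothetical check, not from the commit: the same three accumulators
# (count, sum, sum of squares) recover mean and variance of a stream
# of values without storing the stream itself.
x = torch.randn(1000)

count = x.numel()
sum_x = x.sum()
sum_sq_x = (x**2).sum()

mean = sum_x / count
variance = sum_sq_x / count - mean**2  # Var[X] = E[X^2] - E[X]^2 (population variance)

assert torch.allclose(variance, x.var(unbiased=False), atol=1e-5)
```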
@@ -296,7 +299,16 @@ class LoRAModule(torch.nn.Module):
     def accumulate_grad(self):
         for param in self.parameters():
             if param.grad is not None:
-                self.all_grad.append(param.grad.view(-1))
+                grad = param.grad.detach().flatten()
+                self.grad_count += grad.numel()
+                # Update running sums
+                if self.sum_grads is None:
+                    self.sum_grads = grad.sum()
+                    self.sum_squared_grads = (grad**2).sum()
+                else:
+                    self.sum_grads += grad.sum()
+                    self.sum_squared_grads += (grad**2).sum()
 
     @property
     def device(self):
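`accumulate_grad` folds each parameter's current `.grad` into the running sums, so it should be called once per backward pass, before gradients are zeroed. A standalone sketch of the same pattern (not from the commit; a plain `nn.Linear` stands in for a `LoRAModule`):

```python
import torch

# Hypothetical sketch of the accumulate-after-backward pattern.
model = torch.nn.Linear(4, 4)
grad_count, sum_grads, sum_squared_grads = 0, None, None

for _ in range(3):  # three micro-batches
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()

    # Mirrors LoRAModule.accumulate_grad: fold the fresh .grad tensors
    # into the running count / sum / sum-of-squares.
    for param in model.parameters():
        if param.grad is not None:
            grad = param.grad.detach().flatten()
            grad_count += grad.numel()
            if sum_grads is None:
                sum_grads = grad.sum()
                sum_squared_grads = (grad**2).sum()
            else:
                sum_grads += grad.sum()
                sum_squared_grads += (grad**2).sum()

    model.zero_grad()

print(grad_count, sum_grads.item(), sum_squared_grads.item())
```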
@@ -984,26 +996,54 @@ class LoRANetwork(torch.nn.Module):
         for lora in self.text_encoder_loras + self.unet_loras:
             lora.accumulate_grad()
 
-    def all_grad(self):
-        all_grad = []
+    def sum_grads(self):
+        sum_grads = []
+        sum_squared_grads = []
+        count = 0
         for lora in self.text_encoder_loras + self.unet_loras:
-            all_grad.append(lora.all_grad)
+            if lora.sum_grads is not None:
+                sum_grads.append(lora.sum_grads)
+            if lora.sum_squared_grads is not None:
+                sum_squared_grads.append(lora.sum_squared_grads)
+            count += lora.grad_count
-        return torch.stack(all_grad)
+        return (
+            torch.stack(sum_grads) if len(sum_grads) > 0 else torch.tensor([]),
+            torch.stack(sum_squared_grads) if len(sum_squared_grads) > 0 else torch.tensor([]),
+            count,
+        )
 
     def gradient_noise_scale(self):
-        mean_grad = torch.mean(self.all_grads(), dim=0)
-
-        # Calculate trace of covariance matrix
-        centered_grads = all_grads - mean_grad
-        trace_cov = torch.mean(torch.sum(centered_grads**2, dim=1))
-
-        # Calculate norm of mean gradient squared
-        grad_norm_squared = torch.sum(mean_grad**2)
+        sum_grads, sum_squared_grads, count = self.sum_grads()
+
+        if count == 0:
+            return None
+
+        # Calculate mean gradient and mean squared gradient
+        mean_grad = torch.mean(sum_grads / count, dim=0)
+        mean_squared_grad = torch.mean(sum_squared_grads / count, dim=0)
+
+        # Variance = E[X²] - E[X]²
+        variance = mean_squared_grad - mean_grad**2
+
+        # GNS = trace(Σ) / ||μ||²
+        # trace(Σ) = sum of variances = count * variance (for uniform variance assumption)
+        trace_cov = count * variance
+        grad_norm_squared = count * mean_grad**2
 
         # Calculate GNS using provided gradient norm squared
         gradient_noise_scale = trace_cov / grad_norm_squared
 
+        # mean_grad = torch.mean(all_grads, dim=0)
+        #
+        # # Calculate trace of covariance matrix
+        # centered_grads = all_grads - mean_grad
+        # trace_cov = torch.mean(torch.sum(centered_grads**2, dim=0))
+        #
+        # # Calculate norm of mean gradient squared
+        # grad_norm_squared = torch.sum(mean_grad**2)
+        #
+        # # Calculate GNS using provided gradient norm squared
+        # gradient_noise_scale = trace_cov / grad_norm_squared
+
         return gradient_noise_scale.item()
 
     def load_weights(self, file):
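The quantity `gradient_noise_scale()` estimates is the "simple" gradient noise scale from McCandlish et al., "An Empirical Model of Large-Batch Training" (2018), which the in-code comment writes as GNS = trace(Σ) / ||μ||². A sketch of the algebra under the commit's uniform-variance approximation (not text from the commit; n denotes grad_count, g a single gradient entry):

```latex
\[
  B_{\text{simple}}
  = \frac{\operatorname{tr}(\Sigma)}{\lVert \mu \rVert^{2}}
  \approx \frac{n\,\bigl(\mathbb{E}[g^{2}] - \mathbb{E}[g]^{2}\bigr)}{n\,\mathbb{E}[g]^{2}}
  = \frac{\mathbb{E}[g^{2}] - \mathbb{E}[g]^{2}}{\mathbb{E}[g]^{2}}
\]
```

The factor of count in trace_cov and grad_norm_squared cancels in the ratio, so the returned value is effectively the per-entry gradient variance divided by the squared mean gradient.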