mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 16:22:28 +00:00
Accumulate gradient sums
This commit is contained in:
@@ -106,6 +106,9 @@ class LoRAModule(torch.nn.Module):
|
||||
self.dropout = dropout
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
self.grad_count = 0
|
||||
self.sum_grads = None
|
||||
self.sum_squared_grads = None
|
||||
|
||||
self.ggpo_sigma = ggpo_sigma
|
||||
self.ggpo_beta = ggpo_beta
|
||||
@@ -296,7 +299,16 @@ class LoRAModule(torch.nn.Module):
|
||||
def accumulate_grad(self):
    """Fold the current parameter gradients into this module's running totals.

    For every parameter returned by ``self.parameters()`` that has a
    gradient, this updates three attributes in place:

    * ``self.grad_count`` -- number of gradient elements seen so far.
    * ``self.sum_grads`` -- scalar tensor, running sum of gradient values.
    * ``self.sum_squared_grads`` -- scalar tensor, running sum of squared
      gradient values.

    Only these scalar summaries are kept; the flattened gradients
    themselves are not stored.
    """
    for param in self.parameters():
        if param.grad is None:
            continue

        # Accumulate in float32: squaring half-precision gradients can
        # overflow to inf (fp16 max is 65504) and would poison the totals.
        grad = param.grad.detach().flatten().float()
        self.grad_count += grad.numel()

        grad_sum = grad.sum()
        grad_squared_sum = (grad**2).sum()

        # Update running sums
        if self.sum_grads is None:
            self.sum_grads = grad_sum
            self.sum_squared_grads = grad_squared_sum
        else:
            self.sum_grads += grad_sum
            self.sum_squared_grads += grad_squared_sum
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
@@ -984,26 +996,54 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.accumulate_grad()
|
||||
|
||||
def sum_grads(self):
    """Collect the per-module gradient summaries of every LoRA module.

    Iterates over ``self.text_encoder_loras + self.unet_loras`` and gathers
    each module's ``sum_grads``, ``sum_squared_grads`` and ``grad_count``.

    Returns:
        A 3-tuple ``(sums, squared_sums, count)``:

        * ``sums`` -- the modules' ``sum_grads`` tensors stacked along a new
          leading dimension, or an empty tensor if no module has one.
        * ``squared_sums`` -- likewise for ``sum_squared_grads``.
        * ``count`` -- the total of all modules' ``grad_count``.
    """
    sums = []
    squared_sums = []
    count = 0
    for lora in self.text_encoder_loras + self.unet_loras:
        if lora.sum_grads is not None:
            sums.append(lora.sum_grads)
        # Guard on the attribute that is actually appended, so a None can
        # never be handed to torch.stack.
        if lora.sum_squared_grads is not None:
            squared_sums.append(lora.sum_squared_grads)
        count += lora.grad_count

    return (
        torch.stack(sums) if len(sums) > 0 else torch.tensor([]),
        torch.stack(squared_sums) if len(squared_sums) > 0 else torch.tensor([]),
        count,
    )
|
||||
|
||||
def gradient_noise_scale(self):
    """Estimate the gradient noise scale from the accumulated gradient sums.

    Uses the totals returned by ``self.sum_grads()`` to compute, over all
    accumulated gradient elements, the mean ``E[X]`` and the mean square
    ``E[X^2]``, then returns ``trace(Sigma) / ||mu||^2`` under a
    uniform-variance assumption (every element shares the same variance
    and mean).

    Returns:
        The noise scale as a Python float, or ``None`` when no gradient
        elements have been accumulated or when the mean gradient is exactly
        zero (the ratio is undefined in that case).
    """
    sum_grads, sum_squared_grads, count = self.sum_grads()

    if count == 0:
        return None

    # Calculate mean gradient and mean squared gradient.
    # The stacked tensors hold one partial sum per module, so the global
    # moments are (total of partial sums) / (total element count). Averaging
    # the partial sums over modules as well would shrink both moments by the
    # number of modules and skew the ratio.
    mean_grad = sum_grads.float().sum() / count
    mean_squared_grad = sum_squared_grads.float().sum() / count

    # Variance = E[X²] - E[X]²; clamp tiny negatives caused by rounding.
    variance = torch.clamp(mean_squared_grad - mean_grad**2, min=0.0)

    # GNS = trace(Σ) / ||μ||²
    # trace(Σ) = sum of variances = count * variance (for uniform variance assumption)
    trace_cov = count * variance
    grad_norm_squared = count * mean_grad**2

    # A zero mean gradient makes the ratio undefined; report "no estimate"
    # instead of returning inf/NaN.
    if grad_norm_squared.item() == 0:
        return None

    return (trace_cov / grad_norm_squared).item()
|
||||
|
||||
def load_weights(self, file):
|
||||
|
||||
Reference in New Issue
Block a user