This commit is contained in:
Dave Lage
2025-06-15 21:32:40 +03:00
committed by GitHub
3 changed files with 82 additions and 0 deletions

View File

@@ -4133,6 +4133,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)
parser.add_argument("--gradient_noise_scale", action="store_true", default=False, help="Calculate the gradient noise scale")
if support_dreambooth:
# DreamBooth training

View File

@@ -106,6 +106,9 @@ class LoRAModule(torch.nn.Module):
self.dropout = dropout
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.grad_count = 0
self.sum_grads = None
self.sum_squared_grads = None
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
@@ -293,6 +296,19 @@ class LoRAModule(torch.nn.Module):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
def accumulate_grad(self):
for param in self.parameters():
if param.grad is not None:
grad = param.grad.detach().flatten()
self.grad_count += grad.numel()
# Update running sums
if self.sum_grads is None:
self.sum_grads = grad.sum()
self.sum_squared_grads = (grad**2).sum()
else:
self.sum_grads += grad.sum()
self.sum_squared_grads += (grad**2).sum()
@property
def device(self):
@@ -976,6 +992,59 @@ class LoRANetwork(torch.nn.Module):
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
def accumulate_grad(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.accumulate_grad()
def sum_grads(self):
sum_grads = []
sum_squared_grads = []
count = 0
for lora in self.text_encoder_loras + self.unet_loras:
if lora.sum_grads is not None:
sum_grads.append(lora.sum_grads)
if lora.sum_grads is not None:
sum_squared_grads.append(lora.sum_squared_grads)
count += lora.grad_count
return (
torch.stack(sum_grads) if len(sum_grads) > 0 else torch.tensor([]),
torch.stack(sum_squared_grads) if len(sum_squared_grads) > 0 else torch.tensor([]),
count
)
def gradient_noise_scale(self):
sum_grads, sum_squared_grads, count = self.sum_grads()
if count == 0:
return None, None
# Calculate mean gradient and mean squared gradient
mean_grad = torch.mean(sum_grads / count, dim=0)
mean_squared_grad = torch.mean(sum_squared_grads / count, dim=0)
# Variance = E[X²] - E[X]²
variance = mean_squared_grad - mean_grad**2
# GNS = trace(Σ) / ||μ||²
# trace(Σ) = sum of variances = count * variance (for uniform variance assumption)
trace_cov = count * variance
grad_norm_squared = count * mean_grad**2
gradient_noise_scale = trace_cov / grad_norm_squared
# mean_grad = torch.mean(all_grads, dim=0)
#
# # Calculate trace of covariance matrix
# centered_grads = all_grads - mean_grad
# trace_cov = torch.mean(torch.sum(centered_grads**2, dim=0))
#
# # Calculate norm of mean gradient squared
# grad_norm_squared = torch.sum(mean_grad**2)
#
# # Calculate GNS using provided gradient norm squared
# gradient_noise_scale = trace_cov / grad_norm_squared
return gradient_noise_scale.item(), variance.item()
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":

View File

@@ -1388,7 +1388,10 @@ class NetworkTrainer:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1
batch_size = 0
for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_batch_size = len(batch['network_multipliers'])
batch_size += current_batch_size
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
@@ -1429,6 +1432,8 @@ class NetworkTrainer:
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
if args.gradient_noise_scale and hasattr(network, "accumulate_grad"):
network.accumulate_grad()
optimizer.step()
lr_scheduler.step()
@@ -1504,6 +1509,10 @@ class NetworkTrainer:
mean_grad_norm,
mean_combined_norm,
)
if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
gns, variance = network.gradient_noise_scale()
if gns is not None and variance is not None:
logs = {**logs, "gns/gradient_noise_scale": gns, "gns/noise_variance": variance, "gns/critcal_batch_size": gns / batch_size}
self.step_logging(accelerator, logs, global_step, epoch + 1)
# VALIDATION PER STEP: global_step is already incremented
@@ -1577,6 +1586,9 @@ class NetworkTrainer:
accelerator.unwrap_model(network).train()
progress_bar.unpause()
if accelerator.sync_gradients:
batch_size = 0 # reset batch size
if global_step >= args.max_train_steps:
break