From fcdae99d5c4fc07a5a2bcb7e843581f58ab3c146 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Sun, 30 Mar 2025 14:34:09 -0400
Subject: [PATCH 1/5] Add gradient noise scale logging

---
 library/train_util.py |  1 +
 networks/lora_flux.py | 29 +++++++++++++++++++++++++++++
 train_network.py      |  4 ++++
 3 files changed, 34 insertions(+)

diff --git a/library/train_util.py b/library/train_util.py
index 1ed1d3c2..e2f43dfa 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4125,6 +4125,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
     )
+    parser.add_argument("--gradient_noise_scale", action="store_true", default=False, help="Calculate the gradient noise scale")

     if support_dreambooth:
         # DreamBooth training
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 92b3979a..997857fc 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -293,6 +293,10 @@ class LoRAModule(torch.nn.Module):
             approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
             self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)

+    def accumulate_grad(self):
+        for param in self.parameters():
+            if param.grad is not None:
+                self.all_grad.append(param.grad.view(-1))

     @property
     def device(self):
@@ -976,6 +980,31 @@ class LoRANetwork(torch.nn.Module):
                 combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
         return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])

+    def accumulate_grad(self):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.accumulate_grad()
+
+    def all_grad(self):
+        all_grad = []
+        for lora in self.text_encoder_loras + self.unet_loras:
+            all_grad.append(lora.all_grad)
+
+        return torch.stack(all_grad)
+
+    def gradient_noise_scale(self):
+        mean_grad = torch.mean(self.all_grads(), dim=0)
+
+        # Calculate trace of covariance matrix
+        centered_grads = all_grads - mean_grad
+        trace_cov = torch.mean(torch.sum(centered_grads**2, dim=1))
+
+        # Calculate norm of mean gradient squared
+        grad_norm_squared = torch.sum(mean_grad**2)
+
+        # Calculate GNS using provided gradient norm squared
+        gradient_noise_scale = trace_cov / grad_norm_squared
+
+        return gradient_noise_scale.item()
+
     def load_weights(self, file):
         if os.path.splitext(file)[1] == ".safetensors":
diff --git a/train_network.py b/train_network.py
index f66cdeb4..4d09e7ce 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1418,6 +1418,8 @@ class NetworkTrainer:
                     network.update_grad_norms()
                 if hasattr(network, "update_norms"):
                     network.update_norms()
+                if args.gradient_noise_scale and hasattr(network, "accumulate_grad"):
+                    network.accumulate_grad()

                 optimizer.step()
                 lr_scheduler.step()
@@ -1491,6 +1493,8 @@ class NetworkTrainer:
                             mean_grad_norm,
                             mean_combined_norm,
                         )
+                    if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
+                        logs = {**logs, "grad/noise_scale": self.gradient_noise_scale()}
                     self.step_logging(accelerator, logs, global_step, epoch + 1)

                     # VALIDATION PER STEP: global_step is already incremented

From 90bcab09d88c5e67ca4acf77096fabe7b7b91485 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Sun, 30 Mar 2025 15:45:55 -0400
Subject: [PATCH 2/5] use network

---
 train_network.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_network.py b/train_network.py
index 4d09e7ce..c618297a 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1494,7 +1494,7 @@ class NetworkTrainer:
                             mean_combined_norm,
                         )
                     if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
-                        logs = {**logs, "grad/noise_scale": self.gradient_noise_scale()}
+                        logs = {**logs, "grad/noise_scale": network.gradient_noise_scale()}
                     self.step_logging(accelerator, logs, global_step, epoch + 1)

                     # VALIDATION PER STEP: global_step is already incremented
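[Note] The quantity PATCH 1-2 wire up is the "simple" gradient noise scale
from McCandlish et al., "An Empirical Model of Large-Batch Training" (2018):
B_simple = tr(Σ) / ||G||², the trace of the per-example gradient covariance
over the squared norm of the mean gradient. A minimal standalone sketch of
that estimator, assuming a batch of flattened per-sample gradients (the
function name and random inputs are illustrative, not part of the patches):

import torch

def simple_gradient_noise_scale(per_sample_grads: torch.Tensor) -> float:
    # per_sample_grads: (num_samples, num_params) flattened gradients
    mean_grad = per_sample_grads.mean(dim=0)       # estimate of the true gradient G
    centered = per_sample_grads - mean_grad
    trace_cov = centered.pow(2).sum(dim=1).mean()  # estimate of tr(Σ)
    grad_norm_sq = mean_grad.pow(2).sum()          # estimate of ||G||²
    return (trace_cov / grad_norm_sq).item()

grads = torch.randn(8, 1000) * 0.5 + 0.1  # 8 noisy gradients around a shared mean
print(simple_gradient_noise_scale(grads))

A large value means the batch gradient is dominated by noise and larger
batches would still help; a small value means the gradient estimate is
already accurate at the current batch size.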
From df8e1ac2f1fcf7f73de14c73cedf98e49791cd32 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Tue, 1 Apr 2025 22:55:52 -0400
Subject: [PATCH 3/5] Accumulate gradient sums

---
 networks/lora_flux.py | 66 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 53 insertions(+), 13 deletions(-)

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 997857fc..9fb7f5a2 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -106,6 +106,9 @@ class LoRAModule(torch.nn.Module):
         self.dropout = dropout
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
+        self.grad_count = 0
+        self.sum_grads = None
+        self.sum_squared_grads = None

         self.ggpo_sigma = ggpo_sigma
         self.ggpo_beta = ggpo_beta
@@ -296,7 +299,16 @@ class LoRAModule(torch.nn.Module):
     def accumulate_grad(self):
         for param in self.parameters():
             if param.grad is not None:
-                self.all_grad.append(param.grad.view(-1))
+                grad = param.grad.detach().flatten()
+                self.grad_count += grad.numel()
+
+                # Update running sums
+                if self.sum_grads is None:
+                    self.sum_grads = grad.sum()
+                    self.sum_squared_grads = (grad**2).sum()
+                else:
+                    self.sum_grads += grad.sum()
+                    self.sum_squared_grads += (grad**2).sum()

     @property
     def device(self):
@@ -984,26 +996,54 @@ class LoRANetwork(torch.nn.Module):
         for lora in self.text_encoder_loras + self.unet_loras:
             lora.accumulate_grad()

-    def all_grad(self):
-        all_grad = []
+    def sum_grads(self):
+        sum_grads = []
+        sum_squared_grads = []
+        count = 0
         for lora in self.text_encoder_loras + self.unet_loras:
-            all_grad.append(lora.all_grad)
+            if lora.sum_grads is not None:
+                sum_grads.append(lora.sum_grads)
+            if lora.sum_squared_grads is not None:
+                sum_squared_grads.append(lora.sum_squared_grads)
+            count += lora.grad_count

-        return torch.stack(all_grad)
+        return (
+            torch.stack(sum_grads) if len(sum_grads) > 0 else torch.tensor([]),
+            torch.stack(sum_squared_grads) if len(sum_squared_grads) > 0 else torch.tensor([]),
+            count
+        )

     def gradient_noise_scale(self):
-        mean_grad = torch.mean(self.all_grads(), dim=0)
+        sum_grads, sum_squared_grads, count = self.sum_grads()

-        # Calculate trace of covariance matrix
-        centered_grads = all_grads - mean_grad
-        trace_cov = torch.mean(torch.sum(centered_grads**2, dim=1))
+        if count == 0:
+            return None

-        # Calculate norm of mean gradient squared
-        grad_norm_squared = torch.sum(mean_grad**2)
+        # Calculate mean gradient and mean squared gradient
+        mean_grad = torch.mean(sum_grads / count, dim=0)
+        mean_squared_grad = torch.mean(sum_squared_grads / count, dim=0)
+
+        # Variance = E[X²] - E[X]²
+        variance = mean_squared_grad - mean_grad**2
+
+        # GNS = trace(Σ) / ||μ||²
+        # trace(Σ) = sum of variances = count * variance (for uniform variance assumption)
+        trace_cov = count * variance
+        grad_norm_squared = count * mean_grad**2

-        # Calculate GNS using provided gradient norm squared
         gradient_noise_scale = trace_cov / grad_norm_squared
-
+        # mean_grad = torch.mean(all_grads, dim=0)
+        #
+        # # Calculate trace of covariance matrix
+        # centered_grads = all_grads - mean_grad
+        # trace_cov = torch.mean(torch.sum(centered_grads**2, dim=0))
+        #
+        # # Calculate norm of mean gradient squared
+        # grad_norm_squared = torch.sum(mean_grad**2)
+        #
+        # # Calculate GNS using provided gradient norm squared
+        # gradient_noise_scale = trace_cov / grad_norm_squared
+
         return gradient_noise_scale.item()

     def load_weights(self, file):
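[Note] PATCH 3 replaces storing every flattened gradient with three running
scalars per module (count, Σg, Σg²), so memory stays constant and the
variance comes from Var[X] = E[X²] - E[X]². This treats all gradient elements
as draws from one scalar distribution (the "uniform variance assumption" in
the patch's own comments), trading per-element covariance structure for O(1)
memory. A condensed sketch of that accumulator, with hypothetical names not
taken from the patch:

import torch

class GradStats:
    """Running scalar statistics over all gradient elements seen so far."""
    def __init__(self):
        self.count = 0
        self.sum = 0.0
        self.sum_sq = 0.0

    def update(self, grad: torch.Tensor):
        flat = grad.detach().flatten()
        self.count += flat.numel()
        self.sum += flat.sum().item()
        self.sum_sq += flat.pow(2).sum().item()

    def variance(self) -> float:
        mean = self.sum / self.count
        return self.sum_sq / self.count - mean**2  # Var[X] = E[X²] - E[X]²

stats = GradStats()
stats.update(torch.randn(64, 64))
print(stats.variance())  # ≈ 1.0 for standard-normal data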
From 11bdf9bd76fcdf97b32c5d301a466f5942d89617 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 3 Apr 2025 14:55:54 -0400
Subject: [PATCH 4/5] Add noise variance. Add critical batch size based on
 variable batch size

---
 networks/lora_flux.py |  4 ++--
 train_network.py      | 10 +++++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 9fb7f5a2..48829eea 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -1017,7 +1017,7 @@ class LoRANetwork(torch.nn.Module):
         sum_grads, sum_squared_grads, count = self.sum_grads()

         if count == 0:
-            return None
+            return None, None

         # Calculate mean gradient and mean squared gradient
         mean_grad = torch.mean(sum_grads / count, dim=0)
         mean_squared_grad = torch.mean(sum_squared_grads / count, dim=0)
@@ -1044,7 +1044,7 @@ class LoRANetwork(torch.nn.Module):
         # # Calculate GNS using provided gradient norm squared
         # gradient_noise_scale = trace_cov / grad_norm_squared

-        return gradient_noise_scale.item()
+        return gradient_noise_scale.item(), variance.item()

     def load_weights(self, file):
         if os.path.splitext(file)[1] == ".safetensors":
diff --git a/train_network.py b/train_network.py
index c618297a..55be9601 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1377,7 +1377,10 @@ class NetworkTrainer:
                 skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
                 initial_step = 1

+            batch_size = 0
             for step, batch in enumerate(skipped_dataloader or train_dataloader):
+                current_batch_size = len(batch['network_multipliers'])
+                batch_size += current_batch_size
                 current_step.value = global_step
                 if initial_step > 0:
                     initial_step -= 1
@@ -1494,7 +1497,9 @@ class NetworkTrainer:
                             mean_combined_norm,
                         )
                     if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
-                        logs = {**logs, "grad/noise_scale": network.gradient_noise_scale()}
+                        gns, variance = network.gradient_noise_scale()
+                        if gns is not None and variance is not None:
+                            logs = {**logs, "gns/gradient_noise_scale": gns, "gns/noise_variance": variance, "gns/critical_batch_size": gns / batch_size}
                     self.step_logging(accelerator, logs, global_step, epoch + 1)

                     # VALIDATION PER STEP: global_step is already incremented
@@ -1561,6 +1566,9 @@ class NetworkTrainer:
                         }
                         self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)

+                if accelerator.sync_gradients:
+                    batch_size = 0  # reset batch size
+
                 restore_rng_state(rng_states)
                 args.min_timestep = original_args_min_timestep
                 args.max_timestep = original_args_max_timestep
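[Note] PATCH 4 threads the per-step sample count through to logging because
gradient accumulation and bucketed batches make the effective batch size
variable: batch_size sums samples between optimizer steps and is reset when
gradients sync. In McCandlish et al.'s framing the noise scale itself
approximates the critical batch size; the key logged here additionally
divides by the effective batch that produced the estimate, per the commit's
"based on variable batch size" intent. A sketch of the logging-side
arithmetic with made-up values (the numbers and logs dict are illustrative
only):

gns, variance = 48.0, 0.003   # as returned by network.gradient_noise_scale()
batch_size = 16               # samples accumulated since the last optimizer step

logs = {}
if gns is not None and variance is not None:
    logs.update({
        "gns/gradient_noise_scale": gns,
        "gns/noise_variance": variance,
        # noise scale normalized by the effective batch actually used this step
        "gns/critical_batch_size": gns / batch_size,
    })
print(logs)  # {'gns/gradient_noise_scale': 48.0, 'gns/noise_variance': 0.003, 'gns/critical_batch_size': 3.0}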
From bf2e5abd2eb5b395c1cc2cc142bf64e8e0e764d6 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 3 Apr 2025 17:57:28 -0400
Subject: [PATCH 5/5] Move batch size syncing outside validation

---
 train_network.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/train_network.py b/train_network.py
index 55be9601..885203cf 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1566,9 +1566,6 @@ class NetworkTrainer:
                         }
                         self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)

-                if accelerator.sync_gradients:
-                    batch_size = 0  # reset batch size
-
                 restore_rng_state(rng_states)
                 args.min_timestep = original_args_min_timestep
                 args.max_timestep = original_args_max_timestep
@@ -1576,6 +1573,9 @@ class NetworkTrainer:
                 accelerator.unwrap_model(network).train()
                 progress_bar.unpause()

+                if accelerator.sync_gradients:
+                    batch_size = 0  # reset batch size
+
                 if global_step >= args.max_train_steps:
                     break
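[Note] With all five patches applied, the feature is opt-in via
--gradient_noise_scale and the per-step flow is: accumulate per-parameter
gradient statistics after backward, compute and log GNS at logging time, then
reset the accumulated batch size once gradients have synced. A condensed,
runnable end-to-end sketch (the Toy module below is a stand-in for the real
LoRA network, not sd-scripts code; the 1e-12 epsilon is added here only to
keep the toy demo finite):

import torch

class ToyLoRA(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lora_down = torch.nn.Linear(8, 4, bias=False)
        self.lora_up = torch.nn.Linear(4, 8, bias=False)
        self.grad_count, self.sum_grads, self.sum_squared_grads = 0, 0.0, 0.0

    def accumulate_grad(self):
        for p in self.parameters():
            if p.grad is not None:
                g = p.grad.detach().flatten()
                self.grad_count += g.numel()
                self.sum_grads += g.sum().item()
                self.sum_squared_grads += g.pow(2).sum().item()

    def gradient_noise_scale(self):
        if self.grad_count == 0:
            return None, None
        mean = self.sum_grads / self.grad_count
        variance = self.sum_squared_grads / self.grad_count - mean**2
        # count cancels between trace(Σ) = count*variance and ||μ||² = count*mean²
        return variance / (mean**2 + 1e-12), variance

net, batch_size = ToyLoRA(), 0
optimizer = torch.optim.SGD(net.parameters(), lr=1e-2)
for step in range(4):
    x = torch.randn(2, 8)
    batch_size += x.shape[0]
    net.lora_up(net.lora_down(x)).pow(2).mean().backward()
    net.accumulate_grad()                    # after backward, before step
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    gns, var = net.gradient_noise_scale()
    print(step, gns, var, gns / batch_size)  # mirrors the gns/* log keys
    batch_size = 0                           # reset once gradients have synced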