From fcdae99d5c4fc07a5a2bcb7e843581f58ab3c146 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Sun, 30 Mar 2025 14:34:09 -0400
Subject: [PATCH] Add gradient noise scale logging

---
 library/train_util.py |  1 +
 networks/lora_flux.py | 36 ++++++++++++++++++++++++++++++++++++
 train_network.py      |  4 ++++
 3 files changed, 41 insertions(+)

diff --git a/library/train_util.py b/library/train_util.py
index 1ed1d3c2..e2f43dfa 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -4125,6 +4125,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
     )
+    parser.add_argument("--gradient_noise_scale", action="store_true", default=False, help="Calculate the gradient noise scale")
 
     if support_dreambooth:
         # DreamBooth training
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 92b3979a..997857fc 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -293,6 +293,14 @@ class LoRAModule(torch.nn.Module):
         approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
         self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
 
+    def accumulate_grad(self):
+        # Snapshot the current gradients as ONE flattened vector per call so the
+        # network can estimate the gradient noise scale across accumulated steps.
+        grads = [p.grad.view(-1) for p in self.parameters() if p.grad is not None]
+        if len(grads) > 0:
+            if not hasattr(self, "all_grad"):
+                self.all_grad = []  # lazily created; one entry per optimizer step
+            self.all_grad.append(torch.cat(grads))
 
     @property
     def device(self):
@@ -976,6 +984,34 @@ class LoRANetwork(torch.nn.Module):
             combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
         return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
 
+    def accumulate_grad(self):
+        for lora in self.text_encoder_loras + self.unet_loras:
+            lora.accumulate_grad()
+
+    def all_grads(self):
+        # One concatenated gradient vector per accumulated step, shape (steps, n_params).
+        loras = [lora for lora in self.text_encoder_loras + self.unet_loras if len(getattr(lora, "all_grad", [])) > 0]
+        if len(loras) == 0:
+            return torch.tensor([])
+        num_steps = min(len(lora.all_grad) for lora in loras)
+        return torch.stack([torch.cat([lora.all_grad[i] for lora in loras]) for i in range(num_steps)])
+
+    def gradient_noise_scale(self):
+        all_grads = self.all_grads()
+        mean_grad = torch.mean(all_grads, dim=0)
+
+        # Calculate trace of covariance matrix
+        centered_grads = all_grads - mean_grad
+        trace_cov = torch.mean(torch.sum(centered_grads**2, dim=1))
+
+        # Calculate norm of mean gradient squared
+        grad_norm_squared = torch.sum(mean_grad**2)
+
+        # GNS = tr(Sigma) / |g|^2 (simplified estimator; McCandlish et al., 2018)
+        gradient_noise_scale = trace_cov / grad_norm_squared
+
+        return gradient_noise_scale.item()
+
     def load_weights(self, file):
         if os.path.splitext(file)[1] == ".safetensors":
diff --git a/train_network.py b/train_network.py
index f66cdeb4..4d09e7ce 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1418,6 +1418,8 @@ class NetworkTrainer:
                         network.update_grad_norms()
                     if hasattr(network, "update_norms"):
                         network.update_norms()
+                    if args.gradient_noise_scale and hasattr(network, "accumulate_grad"):
+                        network.accumulate_grad()
 
                     optimizer.step()
                     lr_scheduler.step()
@@ -1491,6 +1493,8 @@ class NetworkTrainer:
                             mean_grad_norm,
                             mean_combined_norm,
                         )
+                    if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
+                        logs = {**logs, "grad/noise_scale": network.gradient_noise_scale()}
 
                     self.step_logging(accelerator, logs, global_step, epoch + 1)
 
                     # VALIDATION PER STEP: global_step is already incremented