From 11bdf9bd76fcdf97b32c5d301a466f5942d89617 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 3 Apr 2025 14:55:54 -0400
Subject: [PATCH] Add noise variance. Add critical batch size based on
 variable batch size

---
 networks/lora_flux.py |  4 ++--
 train_network.py      | 10 +++++++++-
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 9fb7f5a2..48829eea 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -1017,7 +1017,7 @@ class LoRANetwork(torch.nn.Module):
         sum_grads, sum_squared_grads, count = self.sum_grads()
 
         if count == 0:
-            return None
+            return None, None
 
         # Calculate mean gradient and mean squared gradient
         mean_grad = torch.mean(sum_grads / count, dim=0)
@@ -1044,7 +1044,7 @@ class LoRANetwork(torch.nn.Module):
         # # Calculate GNS using provided gradient norm squared
         # gradient_noise_scale = trace_cov / grad_norm_squared
 
-        return gradient_noise_scale.item()
+        return gradient_noise_scale.item(), variance.item()
 
     def load_weights(self, file):
         if os.path.splitext(file)[1] == ".safetensors":
diff --git a/train_network.py b/train_network.py
index c618297a..55be9601 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1377,7 +1377,10 @@ class NetworkTrainer:
                 skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
                 initial_step = 1
 
+            batch_size = 0
             for step, batch in enumerate(skipped_dataloader or train_dataloader):
+                current_batch_size = len(batch['network_multipliers'])
+                batch_size += current_batch_size
                 current_step.value = global_step
                 if initial_step > 0:
                     initial_step -= 1
@@ -1494,7 +1497,9 @@ class NetworkTrainer:
                             mean_combined_norm,
                         )
                     if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
-                        logs = {**logs, "grad/noise_scale": network.gradient_noise_scale()}
+                        gns, variance = network.gradient_noise_scale()
+                        if gns is not None and variance is not None:
+                            logs = {**logs, "gns/gradient_noise_scale": gns, "gns/noise_variance": variance, "gns/critical_batch_size": gns / batch_size}
                     self.step_logging(accelerator, logs, global_step, epoch + 1)
 
                     # VALIDATION PER STEP: global_step is already incremented
@@ -1561,6 +1566,9 @@
                         }
                         self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
 
+                if accelerator.sync_gradients:
+                    batch_size = 0  # reset accumulated batch size after an optimizer step
+
                 restore_rng_state(rng_states)
                 args.min_timestep = original_args_min_timestep
                 args.max_timestep = original_args_max_timestep
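
Note for reviewers: the hunks above change `gradient_noise_scale()` to return a `(gns, variance)` pair, accumulate the variable per-micro-batch sizes into `batch_size`, and log `gns / batch_size` as a critical batch size. For context, below is a minimal, self-contained sketch of the quantities involved, following the simple noise scale B_simple = tr(Sigma) / |G|^2 from McCandlish et al., "An Empirical Model of Large-Batch Training" (2018), which the commented-out `trace_cov / grad_norm_squared` line appears to reference. The function and variable names here (`gradient_noise_stats`, `per_step_grads`) are illustrative only and are not part of this patch or the sd-scripts codebase.

```python
import torch

def gradient_noise_stats(per_step_grads: list[torch.Tensor]) -> tuple[float, float]:
    """Illustrative (gns, variance) estimate from per-micro-batch gradients.

    per_step_grads: flattened gradient vectors, one per micro-batch, roughly
    what an accumulator like LoRANetwork.sum_grads() collects between
    optimizer steps.
    """
    stacked = torch.stack(per_step_grads)        # (num_steps, num_params)
    mean_grad = stacked.mean(dim=0)              # E[g] per parameter
    mean_sq_grad = stacked.pow(2).mean(dim=0)    # E[g^2] per parameter

    # Per-parameter noise variance Var[g] = E[g^2] - E[g]^2; summing it
    # approximates tr(Sigma), the trace of the gradient covariance.
    variance = (mean_sq_grad - mean_grad.pow(2)).sum()

    # Simple noise scale: B_simple = tr(Sigma) / |G|^2 (McCandlish et al. 2018).
    grad_norm_sq = mean_grad.pow(2).sum().clamp_min(1e-12)  # guard divide-by-zero
    gns = variance / grad_norm_sq
    return gns.item(), variance.item()
```

With that reading, dividing the logged noise scale by the accumulated `batch_size` (the `gns/critical_batch_size` entry) normalizes it per training example seen since the last `accelerator.sync_gradients`, which is why the reset in the final hunk matters when batch sizes vary between steps.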