Add noise variance. Add critical batch size based on variable batch size

This commit is contained in:
rockerBOO
2025-04-03 14:55:54 -04:00
parent df8e1ac2f1
commit 11bdf9bd76
2 changed files with 11 additions and 3 deletions

View File

@@ -1017,7 +1017,7 @@ class LoRANetwork(torch.nn.Module):
sum_grads, sum_squared_grads, count = self.sum_grads()
if count == 0:
- return None
+ return None, None
# Calculate mean gradient and mean squared gradient
mean_grad = torch.mean(sum_grads / count, dim=0)
@@ -1044,7 +1044,7 @@ class LoRANetwork(torch.nn.Module):
# # Calculate GNS using provided gradient norm squared
# gradient_noise_scale = trace_cov / grad_norm_squared
- return gradient_noise_scale.item()
+ return gradient_noise_scale.item(), variance.item()
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":

View File

@@ -1377,7 +1377,10 @@ class NetworkTrainer:
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
initial_step = 1
batch_size = 0
for step, batch in enumerate(skipped_dataloader or train_dataloader):
current_batch_size = len(batch['network_multipliers'])
batch_size += current_batch_size
current_step.value = global_step
if initial_step > 0:
initial_step -= 1
@@ -1494,7 +1497,9 @@ class NetworkTrainer:
mean_combined_norm,
)
if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
- logs = {**logs, "grad/noise_scale": network.gradient_noise_scale()}
+ gns, variance = network.gradient_noise_scale()
+ if gns is not None and variance is not None:
+     logs = {**logs, "gns/gradient_noise_scale": gns, "gns/noise_variance": variance, "gns/critical_batch_size": gns / batch_size}
self.step_logging(accelerator, logs, global_step, epoch + 1)
# VALIDATION PER STEP: global_step is already incremented
@@ -1561,6 +1566,9 @@ class NetworkTrainer:
}
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
if accelerator.sync_gradients:
batch_size = 0 # reset batch size
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep