mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 16:22:28 +00:00
Add noise variance. Add critical batch size based on variable batch size
This commit is contained in:
@@ -1017,7 +1017,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
sum_grads, sum_squared_grads, count = self.sum_grads()
|
||||
|
||||
if count == 0:
|
||||
return None
|
||||
return None, None
|
||||
|
||||
# Calculate mean gradient and mean squared gradient
|
||||
mean_grad = torch.mean(sum_grads / count, dim=0)
|
||||
@@ -1044,7 +1044,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
# # Calculate GNS using provided gradient norm squared
|
||||
# gradient_noise_scale = trace_cov / grad_norm_squared
|
||||
|
||||
return gradient_noise_scale.item()
|
||||
return gradient_noise_scale.item(), variance.item()
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
|
||||
@@ -1377,7 +1377,10 @@ class NetworkTrainer:
|
||||
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1)
|
||||
initial_step = 1
|
||||
|
||||
batch_size = 0
|
||||
for step, batch in enumerate(skipped_dataloader or train_dataloader):
|
||||
current_batch_size = len(batch['network_multipliers'])
|
||||
batch_size += current_batch_size
|
||||
current_step.value = global_step
|
||||
if initial_step > 0:
|
||||
initial_step -= 1
|
||||
@@ -1494,7 +1497,9 @@ class NetworkTrainer:
|
||||
mean_combined_norm,
|
||||
)
|
||||
if args.gradient_noise_scale and hasattr(network, "gradient_noise_scale"):
|
||||
logs = {**logs, "grad/noise_scale": network.gradient_noise_scale()}
|
||||
gns, variance = network.gradient_noise_scale()
|
||||
if gns is not None and variance is not None:
|
||||
logs = {**logs, "gns/gradient_noise_scale": gns, "gns/noise_variance": variance, "gns/critcal_batch_size": gns / batch_size}
|
||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||
|
||||
# VALIDATION PER STEP: global_step is already incremented
|
||||
@@ -1561,6 +1566,9 @@ class NetworkTrainer:
|
||||
}
|
||||
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
batch_size = 0 # reset batch size
|
||||
|
||||
restore_rng_state(rng_states)
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
|
||||
Reference in New Issue
Block a user