use same noise for every validation

This commit is contained in:
Kohya S
2025-01-27 22:10:38 +09:00
parent 42c0a9e1fc
commit 45ec02b2a8
2 changed files with 6 additions and 1 deletions

View File

@@ -1391,6 +1391,8 @@ class NetworkTrainer:
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_state = torch.get_rng_state()
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm(
range(validation_total_steps),
@@ -1451,6 +1453,7 @@ class NetworkTrainer:
}
accelerator.log(logs, step=global_step)
torch.set_rng_state(rng_state)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
@@ -1467,6 +1470,8 @@ class NetworkTrainer:
if should_validate_epoch and len(val_dataloader) > 0:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_state = torch.get_rng_state()
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm(
range(validation_total_steps),
@@ -1531,6 +1536,7 @@ class NetworkTrainer:
}
accelerator.log(logs, step=global_step)
torch.set_rng_state(rng_state)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()