mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
use same noise for every validation
This commit is contained in:
@@ -1391,6 +1391,8 @@ class NetworkTrainer:
|
||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||
optimizer_eval_fn()
|
||||
accelerator.unwrap_model(network).eval()
|
||||
rng_state = torch.get_rng_state()
|
||||
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_total_steps),
|
||||
@@ -1451,6 +1453,7 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
optimizer_train_fn()
|
||||
@@ -1467,6 +1470,8 @@ class NetworkTrainer:
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
optimizer_eval_fn()
|
||||
accelerator.unwrap_model(network).eval()
|
||||
rng_state = torch.get_rng_state()
|
||||
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_total_steps),
|
||||
@@ -1531,6 +1536,7 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
optimizer_train_fn()
|
||||
|
||||
Reference in New Issue
Block a user