diff --git a/train_network.py b/train_network.py index b3c7ff52..083e5993 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,7 +1278,7 @@ class NetworkTrainer: original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep - def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: cpu_rng_state = torch.get_rng_state() if accelerator.device.type == "cuda": gpu_rng_state = torch.cuda.get_rng_state() @@ -1289,9 +1289,13 @@ class NetworkTrainer: else: gpu_rng_state = None python_rng_state = random.getstate() + + torch.manual_seed(seed) + random.seed(seed) + return (cpu_rng_state, gpu_rng_state, python_rng_state) - def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): cpu_rng_state, gpu_rng_state, python_rng_state = rng_states torch.set_rng_state(cpu_rng_state) if gpu_rng_state is not None: @@ -1416,8 +1420,7 @@ class NetworkTrainer: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1478,7 +1481,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1495,8 +1498,7 @@ class NetworkTrainer: if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_states = get_rng_state() - torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) + rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( range(validation_total_steps), @@ -1561,7 +1563,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) - set_rng_state(rng_states) + restore_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn()