From c5b803ce94bd70812e6979ac7b986a769659b14e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Feb 2025 21:59:09 +0900 Subject: [PATCH] rng state management: Implement functions to get and set RNG states for consistent validation --- train_network.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index f0deb67a..b3c7ff52 100644 --- a/train_network.py +++ b/train_network.py @@ -1278,6 +1278,31 @@ class NetworkTrainer: original_args_min_timestep = args.min_timestep original_args_max_timestep = args.max_timestep + def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]: + cpu_rng_state = torch.get_rng_state() + if accelerator.device.type == "cuda": + gpu_rng_state = torch.cuda.get_rng_state() + elif accelerator.device.type == "xpu": + gpu_rng_state = torch.xpu.get_rng_state() + elif accelerator.device.type == "mps": + gpu_rng_state = torch.cuda.get_rng_state() + else: + gpu_rng_state = None + python_rng_state = random.getstate() + return (cpu_rng_state, gpu_rng_state, python_rng_state) + + def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]): + cpu_rng_state, gpu_rng_state, python_rng_state = rng_states + torch.set_rng_state(cpu_rng_state) + if gpu_rng_state is not None: + if accelerator.device.type == "cuda": + torch.cuda.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "xpu": + torch.xpu.set_rng_state(gpu_rng_state) + elif accelerator.device.type == "mps": + torch.cuda.set_rng_state(gpu_rng_state) + random.setstate(python_rng_state) + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") current_epoch.value = epoch + 1 @@ -1391,7 +1416,7 @@ class NetworkTrainer: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1453,7 +1478,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn() @@ -1470,7 +1495,7 @@ class NetworkTrainer: if should_validate_epoch and len(val_dataloader) > 0: optimizer_eval_fn() accelerator.unwrap_model(network).eval() - rng_state = torch.get_rng_state() + rng_states = get_rng_state() torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed) val_progress_bar = tqdm( @@ -1536,7 +1561,7 @@ class NetworkTrainer: } accelerator.log(logs, step=global_step) - torch.set_rng_state(rng_state) + set_rng_state(rng_states) args.min_timestep = original_args_min_timestep args.max_timestep = original_args_max_timestep optimizer_train_fn()