mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
rng state management: Implement functions to get and set RNG states for consistent validation
This commit is contained in:
@@ -1278,6 +1278,31 @@ class NetworkTrainer:
|
|||||||
original_args_min_timestep = args.min_timestep
|
original_args_min_timestep = args.min_timestep
|
||||||
original_args_max_timestep = args.max_timestep
|
original_args_max_timestep = args.max_timestep
|
||||||
|
|
||||||
|
def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
    """Snapshot the torch CPU, device and Python ``random`` RNG states.

    The device is chosen from ``accelerator.device.type``; ``accelerator`` is
    not a parameter and is read from the enclosing scope.

    Returns:
        A ``(cpu_rng_state, gpu_rng_state, python_rng_state)`` tuple.
        ``gpu_rng_state`` is ``None`` when the device type is not one of
        "cuda", "xpu" or "mps".
    """
    cpu_rng_state = torch.get_rng_state()

    device_type = accelerator.device.type
    if device_type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state()
    elif device_type == "xpu":
        gpu_rng_state = torch.xpu.get_rng_state()
    elif device_type == "mps":
        # MPS keeps its own generator state, exposed through torch.mps rather
        # than torch.cuda (which this branch previously queried).
        gpu_rng_state = torch.mps.get_rng_state()
    else:
        gpu_rng_state = None

    python_rng_state = random.getstate()
    return (cpu_rng_state, gpu_rng_state, python_rng_state)
|
||||||
|
|
||||||
|
def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
    """Restore the torch CPU, device and Python ``random`` RNG states.

    The device is chosen from ``accelerator.device.type``; ``accelerator`` is
    not a parameter and is read from the enclosing scope.

    Args:
        rng_states: A ``(cpu_rng_state, gpu_rng_state, python_rng_state)``
            tuple. When ``gpu_rng_state`` is ``None`` the device generator is
            left untouched.
    """
    cpu_rng_state, gpu_rng_state, python_rng_state = rng_states

    torch.set_rng_state(cpu_rng_state)

    if gpu_rng_state is not None:
        device_type = accelerator.device.type
        if device_type == "cuda":
            torch.cuda.set_rng_state(gpu_rng_state)
        elif device_type == "xpu":
            torch.xpu.set_rng_state(gpu_rng_state)
        elif device_type == "mps":
            # MPS keeps its own generator state, exposed through torch.mps
            # rather than torch.cuda (which this branch previously called).
            torch.mps.set_rng_state(gpu_rng_state)

    random.setstate(python_rng_state)
|
||||||
|
|
||||||
for epoch in range(epoch_to_start, num_train_epochs):
|
for epoch in range(epoch_to_start, num_train_epochs):
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
@@ -1391,7 +1416,7 @@ class NetworkTrainer:
|
|||||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
accelerator.unwrap_model(network).eval()
|
accelerator.unwrap_model(network).eval()
|
||||||
rng_state = torch.get_rng_state()
|
rng_states = get_rng_state()
|
||||||
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
@@ -1453,7 +1478,7 @@ class NetworkTrainer:
|
|||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
torch.set_rng_state(rng_state)
|
set_rng_state(rng_states)
|
||||||
args.min_timestep = original_args_min_timestep
|
args.min_timestep = original_args_min_timestep
|
||||||
args.max_timestep = original_args_max_timestep
|
args.max_timestep = original_args_max_timestep
|
||||||
optimizer_train_fn()
|
optimizer_train_fn()
|
||||||
@@ -1470,7 +1495,7 @@ class NetworkTrainer:
|
|||||||
if should_validate_epoch and len(val_dataloader) > 0:
|
if should_validate_epoch and len(val_dataloader) > 0:
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
accelerator.unwrap_model(network).eval()
|
accelerator.unwrap_model(network).eval()
|
||||||
rng_state = torch.get_rng_state()
|
rng_states = get_rng_state()
|
||||||
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
@@ -1536,7 +1561,7 @@ class NetworkTrainer:
|
|||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
torch.set_rng_state(rng_state)
|
set_rng_state(rng_states)
|
||||||
args.min_timestep = original_args_min_timestep
|
args.min_timestep = original_args_min_timestep
|
||||||
args.max_timestep = original_args_max_timestep
|
args.max_timestep = original_args_max_timestep
|
||||||
optimizer_train_fn()
|
optimizer_train_fn()
|
||||||
|
|||||||
Reference in New Issue
Block a user