mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Save and seed Python's `random` state alongside the torch RNG state for validation
This commit is contained in:
@@ -1278,7 +1278,7 @@ class NetworkTrainer:
|
||||
original_args_min_timestep = args.min_timestep
|
||||
original_args_max_timestep = args.max_timestep
|
||||
|
||||
def get_rng_state() -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
|
||||
def switch_rng_state(seed:int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
|
||||
cpu_rng_state = torch.get_rng_state()
|
||||
if accelerator.device.type == "cuda":
|
||||
gpu_rng_state = torch.cuda.get_rng_state()
|
||||
@@ -1289,9 +1289,13 @@ class NetworkTrainer:
|
||||
else:
|
||||
gpu_rng_state = None
|
||||
python_rng_state = random.getstate()
|
||||
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
return (cpu_rng_state, gpu_rng_state, python_rng_state)
|
||||
|
||||
def set_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
|
||||
def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
|
||||
cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
|
||||
torch.set_rng_state(cpu_rng_state)
|
||||
if gpu_rng_state is not None:
|
||||
@@ -1416,8 +1420,7 @@ class NetworkTrainer:
|
||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||
optimizer_eval_fn()
|
||||
accelerator.unwrap_model(network).eval()
|
||||
rng_states = get_rng_state()
|
||||
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_total_steps),
|
||||
@@ -1478,7 +1481,7 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
set_rng_state(rng_states)
|
||||
restore_rng_state(rng_states)
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
optimizer_train_fn()
|
||||
@@ -1495,8 +1498,7 @@ class NetworkTrainer:
|
||||
if should_validate_epoch and len(val_dataloader) > 0:
|
||||
optimizer_eval_fn()
|
||||
accelerator.unwrap_model(network).eval()
|
||||
rng_states = get_rng_state()
|
||||
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||
|
||||
val_progress_bar = tqdm(
|
||||
range(validation_total_steps),
|
||||
@@ -1561,7 +1563,7 @@ class NetworkTrainer:
|
||||
}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
set_rng_state(rng_states)
|
||||
restore_rng_state(rng_states)
|
||||
args.min_timestep = original_args_min_timestep
|
||||
args.max_timestep = original_args_max_timestep
|
||||
optimizer_train_fn()
|
||||
|
||||
Reference in New Issue
Block a user