diff --git a/library/train_util.py b/library/train_util.py index 7b553363..ea433979 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3344,7 +3344,7 @@ def sample_images( os.makedirs(save_dir, exist_ok=True) rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state() + cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None with torch.no_grad(): with accelerator.autocast(): @@ -3451,7 +3451,8 @@ def sample_images( torch.cuda.empty_cache() torch.set_rng_state(rng_state) - torch.cuda.set_rng_state(cuda_rng_state) + if cuda_rng_state is not None: + torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device)