do not save cuda_rng_state if no cuda closes #390

This commit is contained in:
Kohya S
2023-05-07 10:23:25 +09:00
parent fdbdb4748a
commit e54b6311ef

View File

@@ -3344,7 +3344,7 @@ def sample_images(
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
rng_state = torch.get_rng_state() rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
with torch.no_grad(): with torch.no_grad():
with accelerator.autocast(): with accelerator.autocast():
@@ -3451,7 +3451,8 @@ def sample_images(
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
torch.cuda.set_rng_state(cuda_rng_state) if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device) vae.to(org_vae_device)