do not save cuda_rng_state if no cuda closes #390

This commit is contained in:
Kohya S
2023-05-07 10:23:25 +09:00
parent fdbdb4748a
commit e54b6311ef

View File

@@ -3344,7 +3344,7 @@ def sample_images(
os.makedirs(save_dir, exist_ok=True)
rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
with torch.no_grad():
with accelerator.autocast():
@@ -3451,6 +3451,7 @@ def sample_images(
torch.cuda.empty_cache()
torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)