Merge pull request #1433 from millie-v/sample-image-without-cuda

Generate sample images without having CUDA (such as on Macs)
This commit is contained in:
Kohya S.
2024-09-07 10:19:55 +09:00
committed by GitHub

View File

@@ -5404,7 +5404,7 @@ def sample_images_common(
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
if cuda_rng_state is not None: if torch.cuda.is_available() and cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state) torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device) vae.to(org_vae_device)
@@ -5438,10 +5438,12 @@ def sample_image_inference(
if seed is not None: if seed is not None:
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed) torch.cuda.manual_seed(seed)
else: else:
# True random sample image generation # True random sample image generation
torch.seed() torch.seed()
if torch.cuda.is_available():
torch.cuda.seed() torch.cuda.seed()
scheduler = get_my_scheduler( scheduler = get_my_scheduler(
@@ -5477,6 +5479,7 @@ def sample_image_inference(
controlnet_image=controlnet_image, controlnet_image=controlnet_image,
) )
if torch.cuda.is_available():
with torch.cuda.device(torch.cuda.current_device()): with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache() torch.cuda.empty_cache()