Merge pull request #1054 from akx/mps

Device support improvements (MPS)
This commit is contained in:
Kohya S
2024-01-31 21:30:12 +09:00
committed by GitHub
22 changed files with 91 additions and 75 deletions

View File

@@ -20,7 +20,6 @@ from typing import (
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
import gc
import glob
import math
import os
@@ -67,6 +66,7 @@ import library.sai_model_spec as sai_model_spec
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.device_utils import clean_memory
from library.original_unet import UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
@@ -2279,8 +2279,7 @@ def cache_batch_latents(
info.latents_flipped = flipped_latent
# FIXME this slows down caching a lot, specify this as an option
if torch.cuda.is_available():
torch.cuda.empty_cache()
clean_memory()
def cache_batch_text_encoder_outputs(
@@ -3920,6 +3919,7 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
)
print("accelerator device:", accelerator.device)
return accelerator
@@ -4006,8 +4006,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
unet.to(accelerator.device)
vae.to(accelerator.device)
gc.collect()
torch.cuda.empty_cache()
clean_memory()
accelerator.wait_for_everyone()
return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4816,7 +4815,7 @@ def sample_images_common(
# clear pipeline and cache to reduce vram usage
del pipeline
torch.cuda.empty_cache()
clean_memory()
torch.set_rng_state(rng_state)
if cuda_rng_state is not None: