mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #1054 from akx/mps
Device support improvements (MPS)
This commit is contained in:
@@ -20,7 +20,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||
import gc
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
@@ -67,6 +66,7 @@ import library.sai_model_spec as sai_model_spec
|
||||
|
||||
# from library.attention_processors import FlashAttnProcessor
|
||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
from library.device_utils import clean_memory
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
@@ -2279,8 +2279,7 @@ def cache_batch_latents(
|
||||
info.latents_flipped = flipped_latent
|
||||
|
||||
# FIXME this slows down caching a lot, specify this as an option
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
clean_memory()
|
||||
|
||||
|
||||
def cache_batch_text_encoder_outputs(
|
||||
@@ -3920,6 +3919,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
)
|
||||
print("accelerator device:", accelerator.device)
|
||||
return accelerator
|
||||
|
||||
|
||||
@@ -4006,8 +4006,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
clean_memory()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
@@ -4816,7 +4815,7 @@ def sample_images_common(
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
clean_memory()
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
if cuda_rng_state is not None:
|
||||
|
||||
Reference in New Issue
Block a user