mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge pull request #676 from Isotr0py/sdxl
Fix RAM leak when loading SDXL model in lowram device
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils.modeling import set_module_tensor_to_device
|
||||||
from safetensors.torch import load_file, save_file
|
from safetensors.torch import load_file, save_file
|
||||||
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
|
||||||
|
from typing import List
|
||||||
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
|
||||||
from library import model_util
|
from library import model_util
|
||||||
from library import sdxl_original_unet
|
from library import sdxl_original_unet
|
||||||
@@ -133,13 +136,43 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
|||||||
return new_sd, logit_scale
|
return new_sd, logit_scale
|
||||||
|
|
||||||
|
|
||||||
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
def _load_state_dict(model, state_dict, device, dtype=None):
|
||||||
|
# dtype will use fp32 as default
|
||||||
|
missing_keys = list(model.state_dict().keys() - state_dict.keys())
|
||||||
|
unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
|
||||||
|
|
||||||
|
# similar to model.load_state_dict()
|
||||||
|
if not missing_keys and not unexpected_keys:
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
|
||||||
|
return '<All keys matched successfully>'
|
||||||
|
|
||||||
|
# error_msgs
|
||||||
|
error_msgs: List[str] = []
|
||||||
|
if missing_keys:
|
||||||
|
error_msgs.insert(
|
||||||
|
0, 'Missing key(s) in state_dict: {}. '.format(
|
||||||
|
', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||||
|
if unexpected_keys:
|
||||||
|
error_msgs.insert(
|
||||||
|
0, 'Unexpected key(s) in state_dict: {}. '.format(
|
||||||
|
', '.join('"{}"'.format(k) for k in unexpected_keys)))
|
||||||
|
|
||||||
|
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
|
||||||
|
model.__class__.__name__, "\n\t".join(error_msgs)))
|
||||||
|
|
||||||
|
|
||||||
|
def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None):
|
||||||
# model_version is reserved for future use
|
# model_version is reserved for future use
|
||||||
|
# dtype is reserved for full_fp16/bf16 integration
|
||||||
|
|
||||||
# Load the state dict
|
# Load the state dict
|
||||||
if model_util.is_safetensors(ckpt_path):
|
if model_util.is_safetensors(ckpt_path):
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
state_dict = load_file(ckpt_path, device=map_location)
|
try:
|
||||||
|
state_dict = load_file(ckpt_path, device=map_location)
|
||||||
|
except:
|
||||||
|
state_dict = load_file(ckpt_path) # prevent device invalid Error
|
||||||
epoch = None
|
epoch = None
|
||||||
global_step = None
|
global_step = None
|
||||||
else:
|
else:
|
||||||
@@ -156,16 +189,16 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location):
|
|||||||
|
|
||||||
# U-Net
|
# U-Net
|
||||||
print("building U-Net")
|
print("building U-Net")
|
||||||
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
with init_empty_weights():
|
||||||
|
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||||
|
|
||||||
print("loading U-Net from checkpoint")
|
print("loading U-Net from checkpoint")
|
||||||
unet_sd = {}
|
unet_sd = {}
|
||||||
for k in list(state_dict.keys()):
|
for k in list(state_dict.keys()):
|
||||||
if k.startswith("model.diffusion_model."):
|
if k.startswith("model.diffusion_model."):
|
||||||
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
|
||||||
info = unet.load_state_dict(unet_sd)
|
info = _load_state_dict(unet, unet_sd, device=map_location)
|
||||||
print("U-Net: ", info)
|
print("U-Net: ", info)
|
||||||
del unet_sd
|
|
||||||
|
|
||||||
# Text Encoders
|
# Text Encoders
|
||||||
print("building text encoders")
|
print("building text encoders")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import torch
|
import torch
|
||||||
|
from accelerate import init_empty_weights
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
|
||||||
@@ -66,7 +67,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
|
|||||||
unet,
|
unet,
|
||||||
logit_scale,
|
logit_scale,
|
||||||
ckpt_info,
|
ckpt_info,
|
||||||
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
|
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, weight_dtype)
|
||||||
else:
|
else:
|
||||||
# Diffusers model is loaded to CPU
|
# Diffusers model is loaded to CPU
|
||||||
from diffusers import StableDiffusionXLPipeline
|
from diffusers import StableDiffusionXLPipeline
|
||||||
@@ -75,7 +76,7 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
|
|||||||
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, variant=variant, tokenizer=None)
|
pipe = StableDiffusionXLPipeline.from_pretrained(name_or_path, torch_dtype=weight_dtype, variant=variant, tokenizer=None)
|
||||||
except EnvironmentError as ex:
|
except EnvironmentError as ex:
|
||||||
if variant is not None:
|
if variant is not None:
|
||||||
print("try to load fp32 model")
|
print("try to load fp32 model")
|
||||||
@@ -95,10 +96,10 @@ def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version
|
|||||||
del pipe
|
del pipe
|
||||||
|
|
||||||
# Diffusers U-Net to original U-Net
|
# Diffusers U-Net to original U-Net
|
||||||
original_unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
|
||||||
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
state_dict = sdxl_model_util.convert_diffusers_unet_state_dict_to_sdxl(unet.state_dict())
|
||||||
original_unet.load_state_dict(state_dict)
|
with init_empty_weights():
|
||||||
unet = original_unet
|
unet = sdxl_original_unet.SdxlUNet2DConditionModel()
|
||||||
|
sdxl_model_util._load_state_dict(unet, state_dict, device=device)
|
||||||
print("U-Net converted to original U-Net")
|
print("U-Net converted to original U-Net")
|
||||||
|
|
||||||
logit_scale = None
|
logit_scale = None
|
||||||
|
|||||||
Reference in New Issue
Block a user