load models one by one

Author: Kohya S
Date: 2024-07-08 22:04:43 +09:00
parent c9de7c4e9a
commit 3ea4fce5e0
3 changed files with 236 additions and 47 deletions

library/sd3_train_utils.py

@@ -1,19 +1,17 @@
 import argparse
 import math
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 from safetensors.torch import save_file
+from accelerate import Accelerator

 from library import sd3_models, sd3_utils, train_util
 from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()

-from accelerate import init_empty_weights
-from tqdm import tqdm

 # from transformers import CLIPTokenizer
 # from library import model_util
 # , sdxl_model_util, train_util, sdxl_original_unet
@@ -28,50 +26,48 @@ logger = logging.getLogger(__name__)
 from .sdxl_train_util import match_mixed_precision


-def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
+def load_target_model(
+    model_type: str,
+    args: argparse.Namespace,
+    state_dict: dict,
+    accelerator: Accelerator,
+    attn_mode: str,
+    model_dtype: Optional[torch.dtype],
+    device: Optional[torch.device],
+) -> Union[
     sd3_models.MMDiT,
     Optional[sd3_models.SDClipModel],
     Optional[sd3_models.SDXLClipG],
     Optional[sd3_models.T5XXLModel],
     sd3_models.SDVAE,
 ]:
-    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16, None or fp16/bf16
+    loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu")

     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
             logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")

-            mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models(
-                args.pretrained_model_name_or_path,
-                args.clip_l,
-                args.clip_g,
-                args.t5xxl,
-                args.vae,
-                attn_mode,
-                accelerator.device if args.lowram else "cpu",
-                model_dtype,
-                args.disable_mmap_load_safetensors,
-                clip_dtype,
-                t5xxl_device,
-                t5xxl_dtype,
-                vae_dtype,
-            )
+            if model_type == "mmdit":
+                model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device)
+            elif model_type == "clip_l":
+                model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device)
+            elif model_type == "clip_g":
+                model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device)
+            elif model_type == "t5xxl":
+                model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device)
+            elif model_type == "vae":
+                model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device)
+            else:
+                raise ValueError(f"Unknown model type: {model_type}")

             # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
             if args.lowram:
-                if clip_l is not None:
-                    clip_l.to(accelerator.device)
-                if clip_g is not None:
-                    clip_g.to(accelerator.device)
-                if t5xxl is not None:
-                    t5xxl.to(accelerator.device)
-                vae.to(accelerator.device)
-                mmdit.to(accelerator.device)
+                model = model.to(accelerator.device)

             clean_memory_on_device(accelerator.device)
         accelerator.wait_for_everyone()

-    return mmdit, clip_l, clip_g, t5xxl, vae
+    return model


 def save_models(

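The refactor above replaces the single all-at-once `load_models` call with a per-model dispatch keyed on `model_type`. A minimal usage sketch of the new calling pattern, assuming the surrounding training-script context (`args`, `accelerator`, and the dtype variables) is in scope; the concrete values such as `attn_mode="torch"` are illustrative assumptions, and the real call sites are in the sd3_train.py diff below:

```python
# Hypothetical usage sketch of the refactored API; not part of the commit.
# Load the combined SD3 checkpoint once, then extract sub-models one by one.
sd3_state_dict = sd3_utils.load_safetensors(
    args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors
)

# Each call pops the matching keys out of sd3_state_dict, so the dict shrinks
# as models are extracted and the full set is never resident twice.
vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, "torch", vae_dtype, "cpu")
clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, "torch", clip_dtype, "cpu")
mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, "torch", model_dtype, "cpu")
```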
library/sd3_utils.py

@@ -20,6 +20,175 @@ from library import sdxl_model_util
 # region models


+def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False):
+    if disable_mmap:
+        return safetensors.torch.load(open(path, "rb").read())
+    else:
+        try:
+            return load_file(path, device=dvc)
+        except:
+            return load_file(path)  # prevent device invalid Error
+
+
+def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
+    mmdit_sd = {}
+
+    mmdit_prefix = "model.diffusion_model."
+    for k in list(state_dict.keys()):
+        if k.startswith(mmdit_prefix):
+            mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)
+
+    # load MMDiT
+    logger.info("Building MMDiT")
+    with init_empty_weights():
+        mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
+
+    logger.info("Loading state dict...")
+    info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
+    logger.info(f"Loaded MMDiT: {info}")
+    return mmdit
+
+
+def load_clip_l(
+    state_dict: Dict,
+    clip_l_path: Optional[str],
+    attn_mode: str,
+    clip_dtype: Optional[Union[str, torch.dtype]],
+    device: Union[str, torch.device],
+    disable_mmap: bool = False,
+):
+    clip_l_sd = None
+    if clip_l_path:
+        logger.info(f"Loading clip_l from {clip_l_path}...")
+        clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap)
+        for key in list(clip_l_sd.keys()):
+            clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
+    else:
+        if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
+            # found clip_l: remove prefix "text_encoders.clip_l."
+            logger.info("clip_l is included in the checkpoint")
+            clip_l_sd = {}
+            prefix = "text_encoders.clip_l."
+            for k in list(state_dict.keys()):
+                if k.startswith(prefix):
+                    clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
+
+    if clip_l_sd is None:
+        clip_l = None
+    else:
+        logger.info("Building ClipL")
+        clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
+
+        logger.info("Loading state dict...")
+        info = clip_l.load_state_dict(clip_l_sd)
+        logger.info(f"Loaded ClipL: {info}")
+        clip_l.set_attn_mode(attn_mode)
+    return clip_l
+
+
+def load_clip_g(
+    state_dict: Dict,
+    clip_g_path: Optional[str],
+    attn_mode: str,
+    clip_dtype: Optional[Union[str, torch.dtype]],
+    device: Union[str, torch.device],
+    disable_mmap: bool = False,
+):
+    clip_g_sd = None
+    if clip_g_path:
+        logger.info(f"Loading clip_g from {clip_g_path}...")
+        clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap)
+        for key in list(clip_g_sd.keys()):
+            clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
+    else:
+        if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
+            # found clip_g: remove prefix "text_encoders.clip_g."
+            logger.info("clip_g is included in the checkpoint")
+            clip_g_sd = {}
+            prefix = "text_encoders.clip_g."
+            for k in list(state_dict.keys()):
+                if k.startswith(prefix):
+                    clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
+
+    if clip_g_sd is None:
+        clip_g = None
+    else:
+        logger.info("Building ClipG")
+        clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
+
+        logger.info("Loading state dict...")
+        info = clip_g.load_state_dict(clip_g_sd)
+        logger.info(f"Loaded ClipG: {info}")
+        clip_g.set_attn_mode(attn_mode)
+    return clip_g
+
+
+def load_t5xxl(
+    state_dict: Dict,
+    t5xxl_path: Optional[str],
+    attn_mode: str,
+    dtype: Optional[Union[str, torch.dtype]],
+    device: Union[str, torch.device],
+    disable_mmap: bool = False,
+):
+    t5xxl_sd = None
+    if t5xxl_path:
+        logger.info(f"Loading t5xxl from {t5xxl_path}...")
+        t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap)
+        for key in list(t5xxl_sd.keys()):
+            t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
+    else:
+        if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
+            # found t5xxl: remove prefix "text_encoders.t5xxl."
+            logger.info("t5xxl is included in the checkpoint")
+            t5xxl_sd = {}
+            prefix = "text_encoders.t5xxl."
+            for k in list(state_dict.keys()):
+                if k.startswith(prefix):
+                    t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
+
+    if t5xxl_sd is None:
+        t5xxl = None
+    else:
+        logger.info("Building T5XXL")
+        # workaround for T5XXL model creation: creating with fp16 takes too long. TODO support virtual device
+        t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd)
+        t5xxl.to(dtype=dtype)
+
+        logger.info("Loading state dict...")
+        info = t5xxl.load_state_dict(t5xxl_sd)
+        logger.info(f"Loaded T5XXL: {info}")
+        t5xxl.set_attn_mode(attn_mode)
+    return t5xxl
+
+
+def load_vae(
+    state_dict: Dict,
+    vae_path: Optional[str],
+    vae_dtype: Optional[Union[str, torch.dtype]],
+    device: Optional[Union[str, torch.device]],
+    disable_mmap: bool = False,
+):
+    vae_sd = {}
+    if vae_path:
+        logger.info(f"Loading VAE from {vae_path}...")
+        vae_sd = load_safetensors(vae_path, device, disable_mmap)
+    else:
+        # remove prefix "first_stage_model."
+        vae_sd = {}
+        vae_prefix = "first_stage_model."
+        for k in list(state_dict.keys()):
+            if k.startswith(vae_prefix):
+                vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
+
+    logger.info("Building VAE")
+    vae = sd3_models.SDVAE()
+    logger.info("Loading state dict...")
+    info = vae.load_state_dict(vae_sd)
+    logger.info(f"Loaded VAE: {info}")
+    vae.to(device=device, dtype=vae_dtype)
+    return vae
+
+
 def load_models(
     ckpt_path: str,
     clip_l_path: str,

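All five loaders added above share the same extraction idiom: the keys belonging to one sub-model are popped out of the combined checkpoint state dict by prefix, so the source dict releases its references as each model is extracted. A self-contained sketch of that pattern with toy tensors; `extract` is a hypothetical helper written for illustration, not a function from this commit:

```python
import torch

def extract(state_dict: dict, prefix: str) -> dict:
    """Pop all keys starting with `prefix` into a new dict, stripping the prefix."""
    sub = {}
    for k in list(state_dict.keys()):  # list() snapshot: we mutate while iterating
        if k.startswith(prefix):
            sub[k[len(prefix):]] = state_dict.pop(k)
    return sub

state_dict = {
    "model.diffusion_model.x.weight": torch.zeros(1),
    "first_stage_model.encoder.w": torch.zeros(1),
    "text_encoders.clip_l.transformer.t": torch.zeros(1),
}
mmdit_sd = extract(state_dict, "model.diffusion_model.")
print(sorted(mmdit_sd))    # ['x.weight']
print(sorted(state_dict))  # the MMDiT key is gone; its tensor can be freed
```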
sd3_train.py

@@ -13,12 +13,12 @@ from tqdm import tqdm
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device

 init_ipex()

 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils
+from library.sdxl_train_util import match_mixed_precision

 # , sdxl_model_util
@@ -189,18 +189,19 @@ def train(args):
     assert (
         attn_mode == "torch"
-    ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
+    ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"

-    # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0.
-    mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
-        args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
+    # SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying.
+    logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}")
+    device_to_load = accelerator.device if args.lowram else "cpu"
+    sd3_state_dict = sd3_utils.load_safetensors(
+        args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors
     )
-    assert clip_l is not None, "clip_l is required / clip_lは必須です"
-    assert clip_g is not None, "clip_g is required / clip_gは必須です"
-    # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)

-    # 学習を準備する
+    # load VAE for caching latents
+    vae: sd3_models.SDVAE = None
     if cache_latents:
+        vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load)
         vae.to(accelerator.device, dtype=vae_dtype)
         vae.requires_grad_(False)
         vae.eval()
@@ -220,15 +221,25 @@ def train(args):
             vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
         )
         train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)

-        vae.to("cpu")
+        vae.to("cpu")  # if no sampling, vae can be deleted
         clean_memory_on_device(accelerator.device)

         accelerator.wait_for_everyone()

+    # load clip_l, clip_g, t5xxl for caching text encoder outputs
+    # # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0.
+    # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
+    #     args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
+    # )
+    clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
+    clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
+    assert clip_l is not None, "clip_l is required / clip_lは必須です"
+    assert clip_g is not None, "clip_g is required / clip_gは必須です"
+    t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load)
+    # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)

     # 学習を準備する:モデルを適切な状態にする
-    if args.gradient_checkpointing:
-        mmdit.enable_gradient_checkpointing()
-
-    train_mmdit = args.learning_rate != 0
     train_clip_l = False
     train_clip_g = False
     train_t5xxl = False
@@ -280,17 +291,30 @@ def train(args):
             accelerator.is_main_process,
             args.text_encoder_batch_size,
         )
+        # TODO we can delete text encoders after caching

         accelerator.wait_for_everyone()

+    # load MMDiT
+    # if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32).
+    # by loading with model_dtype, we can reduce memory usage.
+    model_dtype = match_mixed_precision(args, weight_dtype)  # None (default) or fp16/bf16 (full_xxxx)
+    mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load)
+
+    if args.gradient_checkpointing:
+        mmdit.enable_gradient_checkpointing()
+
+    train_mmdit = args.learning_rate != 0
+    mmdit.requires_grad_(train_mmdit)
+    if not train_mmdit:
+        mmdit.to(accelerator.device, dtype=weight_dtype)  # because mmdit will not be prepared
+
     if not cache_latents:
+        # load VAE here if not cached
+        vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load)
         vae.requires_grad_(False)
         vae.eval()
         vae.to(accelerator.device, dtype=vae_dtype)

-    mmdit.requires_grad_(train_mmdit)
-    if not train_mmdit:
-        mmdit.to(accelerator.device, dtype=weight_dtype)  # because of unet is not prepared

     training_models = []
     params_to_optimize = []
     # if train_unet:
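For reference, the `model_dtype` behavior the new MMDiT block depends on, as a hedged sketch inferred only from the comments in the diff above; the actual `match_mixed_precision` lives in library/sdxl_train_util.py and may differ in detail:

```python
from typing import Optional
import torch

def match_mixed_precision_sketch(args, weight_dtype: torch.dtype) -> Optional[torch.dtype]:
    # Sketch (assumption): with --full_fp16/--full_bf16 the weights themselves are
    # loaded in the half dtype; otherwise None is returned, weights stay float32,
    # and mixed-precision autocast handles the compute dtype.
    if getattr(args, "full_fp16", False):
        assert weight_dtype == torch.float16, "full_fp16 requires mixed_precision fp16"
        return weight_dtype
    if getattr(args, "full_bf16", False):
        assert weight_dtype == torch.bfloat16, "full_bf16 requires mixed_precision bf16"
        return weight_dtype
    return None  # load in float32
```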