diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py
index 9309ee30..98ee66bf 100644
--- a/library/sd3_train_utils.py
+++ b/library/sd3_train_utils.py
@@ -1,19 +1,17 @@
 import argparse
 import math
 import os
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 
 import torch
 from safetensors.torch import save_file
+from accelerate import Accelerator
 
 from library import sd3_models, sd3_utils, train_util
 from library.device_utils import init_ipex, clean_memory_on_device
 
 init_ipex()
 
-from accelerate import init_empty_weights
-from tqdm import tqdm
-
 # from transformers import CLIPTokenizer
 # from library import model_util
 # , sdxl_model_util, train_util, sdxl_original_unet
@@ -28,50 +26,48 @@ logger = logging.getLogger(__name__)
 from .sdxl_train_util import match_mixed_precision
 
 
-def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
+def load_target_model(
+    model_type: str,
+    args: argparse.Namespace,
+    state_dict: dict,
+    accelerator: Accelerator,
+    attn_mode: str,
+    model_dtype: Optional[torch.dtype],
+    device: Optional[torch.device],
+) -> Union[
     sd3_models.MMDiT,
     Optional[sd3_models.SDClipModel],
    Optional[sd3_models.SDXLClipG],
     Optional[sd3_models.T5XXLModel],
     sd3_models.SDVAE,
 ]:
-    model_dtype = match_mixed_precision(args, weight_dtype)  # prepare fp16/bf16, None or fp16/bf16
+    loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu")
 
     for pi in range(accelerator.state.num_processes):
         if pi == accelerator.state.local_process_index:
             logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
 
-            mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models(
-                args.pretrained_model_name_or_path,
-                args.clip_l,
-                args.clip_g,
-                args.t5xxl,
-                args.vae,
-                attn_mode,
-                accelerator.device if args.lowram else "cpu",
-                model_dtype,
-                args.disable_mmap_load_safetensors,
-                clip_dtype,
-                t5xxl_device,
-                t5xxl_dtype,
-                vae_dtype,
-            )
+            if model_type == "mmdit":
+                model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device)
+            elif model_type == "clip_l":
+                model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device)
+            elif model_type == "clip_g":
+                model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device)
+            elif model_type == "t5xxl":
+                model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device)
+            elif model_type == "vae":
+                model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device)
+            else:
+                raise ValueError(f"Unknown model type: {model_type}")
 
             # work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
             if args.lowram:
-                if clip_l is not None:
-                    clip_l.to(accelerator.device)
-                if clip_g is not None:
-                    clip_g.to(accelerator.device)
-                if t5xxl is not None:
-                    t5xxl.to(accelerator.device)
-                vae.to(accelerator.device)
-                mmdit.to(accelerator.device)
+                model = model.to(accelerator.device)
 
             clean_memory_on_device(accelerator.device)
         accelerator.wait_for_everyone()
 
-    return mmdit, clip_l, clip_g, t5xxl, vae
+    return model
 
 
 def save_models(
diff --git a/library/sd3_utils.py b/library/sd3_utils.py
index 9dc9e796..16f80c60 100644
--- a/library/sd3_utils.py
+++ b/library/sd3_utils.py
@@ -20,6 +20,175 @@ from library import sdxl_model_util
 
 # region models
 
+
+def load_safetensors(path: str, device: Union[str, torch.device], disable_mmap: bool = False):
+    if disable_mmap:
+        return safetensors.torch.load(open(path, "rb").read())
+    else:
+        try:
+            return load_file(path, device=device)
+        except Exception:
+            return load_file(path)  # retry on the default device to avoid a "device invalid" error
+
+
+def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
+    mmdit_sd = {}
+
+    mmdit_prefix = "model.diffusion_model."
+    for k in list(state_dict.keys()):
+        if k.startswith(mmdit_prefix):
+            mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)
+
+    # load MMDiT
+    logger.info("Building MMDiT")
+    with init_empty_weights():
+        mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
+
+    logger.info("Loading state dict...")
+    info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
+    logger.info(f"Loaded MMDiT: {info}")
+    return mmdit
+
+
+def load_clip_l(
+    state_dict: Dict,
+    clip_l_path: Optional[str],
+    attn_mode: str,
+    clip_dtype: Optional[Union[str, torch.dtype]],
+    device: Union[str, torch.device],
+    disable_mmap: bool = False,
+):
+    clip_l_sd = None
+    if clip_l_path:
+        logger.info(f"Loading clip_l from {clip_l_path}...")
+        clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap)
+        for key in list(clip_l_sd.keys()):
+            clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
+    else:
+        if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
+            # found clip_l: remove prefix "text_encoders.clip_l."
+            logger.info("clip_l is included in the checkpoint")
+            clip_l_sd = {}
+            prefix = "text_encoders.clip_l."
+            for k in list(state_dict.keys()):
+                if k.startswith(prefix):
+                    clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
+
+    if clip_l_sd is None:
+        clip_l = None
+    else:
+        logger.info("Building ClipL")
+        clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
+        logger.info("Loading state dict...")
+        info = clip_l.load_state_dict(clip_l_sd)
+        logger.info(f"Loaded ClipL: {info}")
+        clip_l.set_attn_mode(attn_mode)
+    return clip_l
+
+
+def load_clip_g(
+    state_dict: Dict,
+    clip_g_path: Optional[str],
+    attn_mode: str,
+    clip_dtype: Optional[Union[str, torch.dtype]],
+    device: Union[str, torch.device],
+    disable_mmap: bool = False,
+):
+    clip_g_sd = None
+    if clip_g_path:
+        logger.info(f"Loading clip_g from {clip_g_path}...")
+        clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap)
+        for key in list(clip_g_sd.keys()):
+            clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
+    else:
+        if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
+            # found clip_g: remove prefix "text_encoders.clip_g."
+            logger.info("clip_g is included in the checkpoint")
+            clip_g_sd = {}
+            prefix = "text_encoders.clip_g."
+            for k in list(state_dict.keys()):
+                if k.startswith(prefix):
+                    clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
+
+    if clip_g_sd is None:
+        clip_g = None
+    else:
+        logger.info("Building ClipG")
+        clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
+        logger.info("Loading state dict...")
+        info = clip_g.load_state_dict(clip_g_sd)
+        logger.info(f"Loaded ClipG: {info}")
+        clip_g.set_attn_mode(attn_mode)
+    return clip_g
+
+
+def load_t5xxl(
+    state_dict: Dict,
+    t5xxl_path: Optional[str],
+    attn_mode: str,
+    dtype: Optional[Union[str, torch.dtype]],
+    device: Union[str, torch.device],
+    disable_mmap: bool = False,
+):
+    t5xxl_sd = None
+    if t5xxl_path:
+        logger.info(f"Loading t5xxl from {t5xxl_path}...")
+        t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap)
+        for key in list(t5xxl_sd.keys()):
+            t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
+    else:
+        if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
+            # found t5xxl: remove prefix "text_encoders.t5xxl."
+            logger.info("t5xxl is included in the checkpoint")
+            t5xxl_sd = {}
+            prefix = "text_encoders.t5xxl."
+            for k in list(state_dict.keys()):
+                if k.startswith(prefix):
+                    t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
+
+    if t5xxl_sd is None:
+        t5xxl = None
+    else:
+        logger.info("Building T5XXL")
+
+        # workaround for T5XXL model creation: creating directly in fp16 takes too long, so create in fp32 and cast. TODO: support a virtual device
+        t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd)
+        t5xxl.to(dtype=dtype)
+
+        logger.info("Loading state dict...")
+        info = t5xxl.load_state_dict(t5xxl_sd)
+        logger.info(f"Loaded T5XXL: {info}")
+        t5xxl.set_attn_mode(attn_mode)
+    return t5xxl
+
+
+def load_vae(
+    state_dict: Dict,
+    vae_path: Optional[str],
+    vae_dtype: Optional[Union[str, torch.dtype]],
+    device: Optional[Union[str, torch.device]],
+    disable_mmap: bool = False,
+):
+    vae_sd = {}
+    if vae_path:
+        logger.info(f"Loading VAE from {vae_path}...")
+        vae_sd = load_safetensors(vae_path, device, disable_mmap)
+    else:
+        # extract from the checkpoint: remove prefix "first_stage_model."
+        vae_sd = {}
+        vae_prefix = "first_stage_model."
+        for k in list(state_dict.keys()):
+            if k.startswith(vae_prefix):
+                vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
+
+    logger.info("Building VAE")
+    vae = sd3_models.SDVAE()
+    logger.info("Loading state dict...")
+    info = vae.load_state_dict(vae_sd)
+    logger.info(f"Loaded VAE: {info}")
+    vae.to(device=device, dtype=vae_dtype)
+    return vae
+
+
 def load_models(
     ckpt_path: str,
     clip_l_path: str,
diff --git a/sd3_train.py b/sd3_train.py
index c073ec0e..10cc5d57 100644
--- a/sd3_train.py
+++ b/sd3_train.py
@@ -13,12 +13,12 @@ from tqdm import tqdm
 
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
-
 init_ipex()
 
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils
+from library.sdxl_train_util import match_mixed_precision
 
 # , sdxl_model_util
@@ -189,18 +189,19 @@ def train(args):
 
     assert (
         attn_mode == "torch"
-    ), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
+    ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
 
-    # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0.
-    mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
-        args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
+    # an SD3 checkpoint may bundle multiple models, so load the state dict once and extract each model from it
+    logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}")
+    device_to_load = accelerator.device if args.lowram else "cpu"
+    sd3_state_dict = sd3_utils.load_safetensors(
+        args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors
     )
-    assert clip_l is not None, "clip_l is required / clip_lは必須です"
-    assert clip_g is not None, "clip_g is required / clip_gは必須です"
-    # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
-    # 学習を準備する
+    # load VAE for caching latents
+    vae: Optional[sd3_models.SDVAE] = None
     if cache_latents:
+        vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load)
         vae.to(accelerator.device, dtype=vae_dtype)
         vae.requires_grad_(False)
         vae.eval()
@@ -220,15 +221,25 @@ def train(args):
             vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
         )
         train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
-        vae.to("cpu")
+        vae.to("cpu")  # if we don't sample images during training, the VAE could be deleted here instead
         clean_memory_on_device(accelerator.device)
 
     accelerator.wait_for_everyone()
 
+    # load clip_l, clip_g, t5xxl for caching text encoder outputs
+    # models are usually loaded on CPU and moved to GPU later, to avoid OOM on GPU 0; the old all-at-once call:
+    # mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
+    #     args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
+    # )
+    clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
+    clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
+    assert clip_l is not None, "clip_l is required / clip_lは必須です"
+    assert clip_g is not None, "clip_g is required / clip_gは必須です"
+
+    t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load)
+    # logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
 
+    # prepare for training: set the models to the appropriate state
-    if args.gradient_checkpointing:
-        mmdit.enable_gradient_checkpointing()
-    train_mmdit = args.learning_rate != 0
     train_clip_l = False
     train_clip_g = False
     train_t5xxl = False
@@ -280,17 +291,30 @@
             accelerator.is_main_process,
             args.text_encoder_batch_size,
         )
+
+        # TODO: the text encoders can be deleted after caching
     accelerator.wait_for_everyone()
 
+    # load MMDiT
+    # if full_fp16/bf16 is enabled, model_dtype is cast to fp16/bf16; otherwise it is None (float32)
+    # loading with model_dtype reduces memory usage
+    model_dtype = match_mixed_precision(args, weight_dtype)  # None (default) or fp16/bf16 (full_xxxx)
+    mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load)
+    if args.gradient_checkpointing:
+        mmdit.enable_gradient_checkpointing()
+
+    train_mmdit = args.learning_rate != 0
+    mmdit.requires_grad_(train_mmdit)
+    if not train_mmdit:
+        mmdit.to(accelerator.device, dtype=weight_dtype)  # mmdit will not be prepared by accelerator, so cast it here
+
     if not cache_latents:
+        # load the VAE here since latents were not cached
+        vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load)
         vae.requires_grad_(False)
         vae.eval()
         vae.to(accelerator.device, dtype=vae_dtype)
 
-    mmdit.requires_grad_(train_mmdit)
-    if not train_mmdit:
-        mmdit.to(accelerator.device, dtype=weight_dtype)  # because of unet is not prepared
-
     training_models = []
     params_to_optimize = []
     # if train_unet:
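
The patch replaces the all-at-once sd3_utils.load_models() call with per-model loaders driven by a single shared state dict. Below is a minimal usage sketch of the new flow; it mirrors the calls in sd3_train.py above. The argparse.Namespace stands in for the parsed training arguments, and the checkpoint path and dtypes are illustrative only.

import argparse
import torch
from accelerate import Accelerator
from library import sd3_utils, sd3_train_utils

# illustrative stand-in for the parsed training arguments
args = argparse.Namespace(
    pretrained_model_name_or_path="sd3_medium.safetensors",  # hypothetical path
    lowram=False,
    disable_mmap_load_safetensors=False,
    clip_l=None, clip_g=None, t5xxl=None, vae=None,  # None: use the weights bundled in the checkpoint
)

accelerator = Accelerator()
device_to_load = accelerator.device if args.lowram else "cpu"

# an SD3 checkpoint bundles MMDiT, CLIP-L/G, T5-XXL and the VAE; load the combined state dict once
sd3_state_dict = sd3_utils.load_safetensors(
    args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors
)

# then extract the sub-models one by one; each loader pops its keys out of the shared dict
vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, "torch", torch.float32, device_to_load)
clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, "torch", torch.float32, device_to_load)
clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, "torch", torch.float32, device_to_load)
t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, "torch", torch.float32, device_to_load)
mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, "torch", None, device_to_load)

Because each loader consumes its keys from the shared dict, the dict shrinks as models are extracted, keeping peak host memory close to a single copy of the checkpoint.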
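All of the new loaders share the same extraction idiom: pop every key under a model-specific prefix out of the combined state dict and re-key it. A standalone sketch of that idiom follows; the helper name is hypothetical, since the loaders in this patch inline the loop.

from typing import Dict
import torch

def extract_by_prefix(state_dict: Dict[str, torch.Tensor], prefix: str) -> Dict[str, torch.Tensor]:
    sub_sd = {}
    for k in list(state_dict.keys()):  # list() because the dict is mutated while iterating
        if k.startswith(prefix):
            # pop() moves the tensor out of the shared dict so it is not held twice
            sub_sd[k[len(prefix):]] = state_dict.pop(k)
    return sub_sd

# e.g. MMDiT weights live under "model.diffusion_model." in an SD3 checkpoint:
# mmdit_sd = extract_by_prefix(sd3_state_dict, "model.diffusion_model.")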