mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
load models one by one
This commit is contained in:
@@ -1,19 +1,17 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
from accelerate import Accelerator
|
||||||
|
|
||||||
from library import sd3_models, sd3_utils, train_util
|
from library import sd3_models, sd3_utils, train_util
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# from transformers import CLIPTokenizer
|
# from transformers import CLIPTokenizer
|
||||||
# from library import model_util
|
# from library import model_util
|
||||||
# , sdxl_model_util, train_util, sdxl_original_unet
|
# , sdxl_model_util, train_util, sdxl_original_unet
|
||||||
@@ -28,50 +26,48 @@ logger = logging.getLogger(__name__)
|
|||||||
from .sdxl_train_util import match_mixed_precision
|
from .sdxl_train_util import match_mixed_precision
|
||||||
|
|
||||||
|
|
||||||
def load_target_model(args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype) -> Tuple[
|
def load_target_model(
|
||||||
|
model_type: str,
|
||||||
|
args: argparse.Namespace,
|
||||||
|
state_dict: dict,
|
||||||
|
accelerator: Accelerator,
|
||||||
|
attn_mode: str,
|
||||||
|
model_dtype: Optional[torch.dtype],
|
||||||
|
device: Optional[torch.device],
|
||||||
|
) -> Union[
|
||||||
sd3_models.MMDiT,
|
sd3_models.MMDiT,
|
||||||
Optional[sd3_models.SDClipModel],
|
Optional[sd3_models.SDClipModel],
|
||||||
Optional[sd3_models.SDXLClipG],
|
Optional[sd3_models.SDXLClipG],
|
||||||
Optional[sd3_models.T5XXLModel],
|
Optional[sd3_models.T5XXLModel],
|
||||||
sd3_models.SDVAE,
|
sd3_models.SDVAE,
|
||||||
]:
|
]:
|
||||||
model_dtype = match_mixed_precision(args, weight_dtype) # prepare fp16/bf16, None or fp16/bf16
|
loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu")
|
||||||
|
|
||||||
for pi in range(accelerator.state.num_processes):
|
for pi in range(accelerator.state.num_processes):
|
||||||
if pi == accelerator.state.local_process_index:
|
if pi == accelerator.state.local_process_index:
|
||||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||||
|
|
||||||
mmdit, clip_l, clip_g, t5xxl, vae = sd3_utils.load_models(
|
if model_type == "mmdit":
|
||||||
args.pretrained_model_name_or_path,
|
model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device)
|
||||||
args.clip_l,
|
elif model_type == "clip_l":
|
||||||
args.clip_g,
|
model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device)
|
||||||
args.t5xxl,
|
elif model_type == "clip_g":
|
||||||
args.vae,
|
model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device)
|
||||||
attn_mode,
|
elif model_type == "t5xxl":
|
||||||
accelerator.device if args.lowram else "cpu",
|
model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device)
|
||||||
model_dtype,
|
elif model_type == "vae":
|
||||||
args.disable_mmap_load_safetensors,
|
model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device)
|
||||||
clip_dtype,
|
else:
|
||||||
t5xxl_device,
|
raise ValueError(f"Unknown model type: {model_type}")
|
||||||
t5xxl_dtype,
|
|
||||||
vae_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
|
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
|
||||||
if args.lowram:
|
if args.lowram:
|
||||||
if clip_l is not None:
|
model = model.to(accelerator.device)
|
||||||
clip_l.to(accelerator.device)
|
|
||||||
if clip_g is not None:
|
|
||||||
clip_g.to(accelerator.device)
|
|
||||||
if t5xxl is not None:
|
|
||||||
t5xxl.to(accelerator.device)
|
|
||||||
vae.to(accelerator.device)
|
|
||||||
mmdit.to(accelerator.device)
|
|
||||||
|
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
return mmdit, clip_l, clip_g, t5xxl, vae
|
return model
|
||||||
|
|
||||||
|
|
||||||
def save_models(
|
def save_models(
|
||||||
|
|||||||
@@ -20,6 +20,175 @@ from library import sdxl_model_util
|
|||||||
# region models
|
# region models
|
||||||
|
|
||||||
|
|
||||||
|
def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False):
|
||||||
|
if disable_mmap:
|
||||||
|
return safetensors.torch.load(open(path, "rb").read())
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return load_file(path, device=dvc)
|
||||||
|
except:
|
||||||
|
return load_file(path) # prevent device invalid Error
|
||||||
|
|
||||||
|
|
||||||
|
def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
|
||||||
|
mmdit_sd = {}
|
||||||
|
|
||||||
|
mmdit_prefix = "model.diffusion_model."
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
if k.startswith(mmdit_prefix):
|
||||||
|
mmdit_sd[k[len(mmdit_prefix) :]] = state_dict.pop(k)
|
||||||
|
|
||||||
|
# load MMDiT
|
||||||
|
logger.info("Building MMDit")
|
||||||
|
with init_empty_weights():
|
||||||
|
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
|
||||||
|
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
|
||||||
|
logger.info(f"Loaded MMDiT: {info}")
|
||||||
|
return mmdit
|
||||||
|
|
||||||
|
|
||||||
|
def load_clip_l(
|
||||||
|
state_dict: Dict,
|
||||||
|
clip_l_path: Optional[str],
|
||||||
|
attn_mode: str,
|
||||||
|
clip_dtype: Optional[Union[str, torch.dtype]],
|
||||||
|
device: Union[str, torch.device],
|
||||||
|
disable_mmap: bool = False,
|
||||||
|
):
|
||||||
|
clip_l_sd = None
|
||||||
|
if clip_l_path:
|
||||||
|
logger.info(f"Loading clip_l from {clip_l_path}...")
|
||||||
|
clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap)
|
||||||
|
for key in list(clip_l_sd.keys()):
|
||||||
|
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
|
||||||
|
else:
|
||||||
|
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||||
|
# found clip_l: remove prefix "text_encoders.clip_l."
|
||||||
|
logger.info("clip_l is included in the checkpoint")
|
||||||
|
clip_l_sd = {}
|
||||||
|
prefix = "text_encoders.clip_l."
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
if k.startswith(prefix):
|
||||||
|
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||||
|
|
||||||
|
if clip_l_sd is None:
|
||||||
|
clip_l = None
|
||||||
|
else:
|
||||||
|
logger.info("Building ClipL")
|
||||||
|
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = clip_l.load_state_dict(clip_l_sd)
|
||||||
|
logger.info(f"Loaded ClipL: {info}")
|
||||||
|
clip_l.set_attn_mode(attn_mode)
|
||||||
|
return clip_l
|
||||||
|
|
||||||
|
|
||||||
|
def load_clip_g(
|
||||||
|
state_dict: Dict,
|
||||||
|
clip_g_path: Optional[str],
|
||||||
|
attn_mode: str,
|
||||||
|
clip_dtype: Optional[Union[str, torch.dtype]],
|
||||||
|
device: Union[str, torch.device],
|
||||||
|
disable_mmap: bool = False,
|
||||||
|
):
|
||||||
|
clip_g_sd = None
|
||||||
|
if clip_g_path:
|
||||||
|
logger.info(f"Loading clip_g from {clip_g_path}...")
|
||||||
|
clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap)
|
||||||
|
for key in list(clip_g_sd.keys()):
|
||||||
|
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
|
||||||
|
else:
|
||||||
|
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||||
|
# found clip_g: remove prefix "text_encoders.clip_g."
|
||||||
|
logger.info("clip_g is included in the checkpoint")
|
||||||
|
clip_g_sd = {}
|
||||||
|
prefix = "text_encoders.clip_g."
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
if k.startswith(prefix):
|
||||||
|
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||||
|
|
||||||
|
if clip_g_sd is None:
|
||||||
|
clip_g = None
|
||||||
|
else:
|
||||||
|
logger.info("Building ClipG")
|
||||||
|
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = clip_g.load_state_dict(clip_g_sd)
|
||||||
|
logger.info(f"Loaded ClipG: {info}")
|
||||||
|
clip_g.set_attn_mode(attn_mode)
|
||||||
|
return clip_g
|
||||||
|
|
||||||
|
|
||||||
|
def load_t5xxl(
|
||||||
|
state_dict: Dict,
|
||||||
|
t5xxl_path: Optional[str],
|
||||||
|
attn_mode: str,
|
||||||
|
dtype: Optional[Union[str, torch.dtype]],
|
||||||
|
device: Union[str, torch.device],
|
||||||
|
disable_mmap: bool = False,
|
||||||
|
):
|
||||||
|
t5xxl_sd = None
|
||||||
|
if t5xxl_path:
|
||||||
|
logger.info(f"Loading t5xxl from {t5xxl_path}...")
|
||||||
|
t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap)
|
||||||
|
for key in list(t5xxl_sd.keys()):
|
||||||
|
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
|
||||||
|
else:
|
||||||
|
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
|
||||||
|
# found t5xxl: remove prefix "text_encoders.t5xxl."
|
||||||
|
logger.info("t5xxl is included in the checkpoint")
|
||||||
|
t5xxl_sd = {}
|
||||||
|
prefix = "text_encoders.t5xxl."
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
if k.startswith(prefix):
|
||||||
|
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||||
|
|
||||||
|
if t5xxl_sd is None:
|
||||||
|
t5xxl = None
|
||||||
|
else:
|
||||||
|
logger.info("Building T5XXL")
|
||||||
|
|
||||||
|
# workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device
|
||||||
|
t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd)
|
||||||
|
t5xxl.to(dtype=dtype)
|
||||||
|
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = t5xxl.load_state_dict(t5xxl_sd)
|
||||||
|
logger.info(f"Loaded T5XXL: {info}")
|
||||||
|
t5xxl.set_attn_mode(attn_mode)
|
||||||
|
return t5xxl
|
||||||
|
|
||||||
|
|
||||||
|
def load_vae(
|
||||||
|
state_dict: Dict,
|
||||||
|
vae_path: Optional[str],
|
||||||
|
vae_dtype: Optional[Union[str, torch.dtype]],
|
||||||
|
device: Optional[Union[str, torch.device]],
|
||||||
|
disable_mmap: bool = False,
|
||||||
|
):
|
||||||
|
vae_sd = {}
|
||||||
|
if vae_path:
|
||||||
|
logger.info(f"Loading VAE from {vae_path}...")
|
||||||
|
vae_sd = load_safetensors(vae_path, device, disable_mmap)
|
||||||
|
else:
|
||||||
|
# remove prefix "first_stage_model."
|
||||||
|
vae_sd = {}
|
||||||
|
vae_prefix = "first_stage_model."
|
||||||
|
for k in list(state_dict.keys()):
|
||||||
|
if k.startswith(vae_prefix):
|
||||||
|
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
|
||||||
|
|
||||||
|
logger.info("Building VAE")
|
||||||
|
vae = sd3_models.SDVAE()
|
||||||
|
logger.info("Loading state dict...")
|
||||||
|
info = vae.load_state_dict(vae_sd)
|
||||||
|
logger.info(f"Loaded VAE: {info}")
|
||||||
|
vae.to(device=device, dtype=vae_dtype)
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
def load_models(
|
def load_models(
|
||||||
ckpt_path: str,
|
ckpt_path: str,
|
||||||
clip_l_path: str,
|
clip_l_path: str,
|
||||||
|
|||||||
58
sd3_train.py
58
sd3_train.py
@@ -13,12 +13,12 @@ from tqdm import tqdm
|
|||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DDPMScheduler
|
from diffusers import DDPMScheduler
|
||||||
from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils
|
from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils
|
||||||
|
from library.sdxl_train_util import match_mixed_precision
|
||||||
|
|
||||||
# , sdxl_model_util
|
# , sdxl_model_util
|
||||||
|
|
||||||
@@ -189,18 +189,19 @@ def train(args):
|
|||||||
|
|
||||||
assert (
|
assert (
|
||||||
attn_mode == "torch"
|
attn_mode == "torch"
|
||||||
), f"attn_mode {attn_mode} is not supported. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
|
), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
|
||||||
|
|
||||||
# models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0.
|
# SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying.
|
||||||
mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
|
logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}")
|
||||||
args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
|
device_to_load = accelerator.device if args.lowram else "cpu"
|
||||||
|
sd3_state_dict = sd3_utils.load_safetensors(
|
||||||
|
args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors
|
||||||
)
|
)
|
||||||
assert clip_l is not None, "clip_l is required / clip_lは必須です"
|
|
||||||
assert clip_g is not None, "clip_g is required / clip_gは必須です"
|
|
||||||
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
# 学習を準備する
|
# load VAE for caching latents
|
||||||
|
vae: sd3_models.SDVAE = None
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
|
vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load)
|
||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
@@ -220,15 +221,25 @@ def train(args):
|
|||||||
vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
||||||
)
|
)
|
||||||
train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
|
train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
|
||||||
vae.to("cpu")
|
vae.to("cpu") # if no sampling, vae can be deleted
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
# load clip_l, clip_g, t5xxl for caching text encoder outputs
|
||||||
|
# # models are usually loaded on CPU and moved to GPU later. This is to avoid OOM on GPU0.
|
||||||
|
# mmdit, clip_l, clip_g, t5xxl, vae = sd3_train_utils.load_target_model(
|
||||||
|
# args, accelerator, attn_mode, weight_dtype, clip_dtype, t5xxl_device, t5xxl_dtype, vae_dtype
|
||||||
|
# )
|
||||||
|
clip_l = sd3_train_utils.load_target_model("clip_l", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
|
||||||
|
clip_g = sd3_train_utils.load_target_model("clip_g", args, sd3_state_dict, accelerator, attn_mode, clip_dtype, device_to_load)
|
||||||
|
assert clip_l is not None, "clip_l is required / clip_lは必須です"
|
||||||
|
assert clip_g is not None, "clip_g is required / clip_gは必須です"
|
||||||
|
|
||||||
|
t5xxl = sd3_train_utils.load_target_model("t5xxl", args, sd3_state_dict, accelerator, attn_mode, t5xxl_dtype, device_to_load)
|
||||||
|
# logit_scale = logit_scale.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
# 学習を準備する:モデルを適切な状態にする
|
# 学習を準備する:モデルを適切な状態にする
|
||||||
if args.gradient_checkpointing:
|
|
||||||
mmdit.enable_gradient_checkpointing()
|
|
||||||
train_mmdit = args.learning_rate != 0
|
|
||||||
train_clip_l = False
|
train_clip_l = False
|
||||||
train_clip_g = False
|
train_clip_g = False
|
||||||
train_t5xxl = False
|
train_t5xxl = False
|
||||||
@@ -280,17 +291,30 @@ def train(args):
|
|||||||
accelerator.is_main_process,
|
accelerator.is_main_process,
|
||||||
args.text_encoder_batch_size,
|
args.text_encoder_batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO we can delete text encoders after caching
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
# load MMDIT
|
||||||
|
# if full_fp16/bf16, model_dtype is casted to fp16/bf16. If not, model_dtype is None (float32).
|
||||||
|
# by loading with model_dtype, we can reduce memory usage.
|
||||||
|
model_dtype = match_mixed_precision(args, weight_dtype) # None (default) or fp16/bf16 (full_xxxx)
|
||||||
|
mmdit = sd3_train_utils.load_target_model("mmdit", args, sd3_state_dict, accelerator, attn_mode, model_dtype, device_to_load)
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
mmdit.enable_gradient_checkpointing()
|
||||||
|
|
||||||
|
train_mmdit = args.learning_rate != 0
|
||||||
|
mmdit.requires_grad_(train_mmdit)
|
||||||
|
if not train_mmdit:
|
||||||
|
mmdit.to(accelerator.device, dtype=weight_dtype) # because of mmdie will not be prepared
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents:
|
||||||
|
# load VAE here if not cached
|
||||||
|
vae = sd3_train_utils.load_target_model("vae", args, sd3_state_dict, accelerator, attn_mode, vae_dtype, device_to_load)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
vae.to(accelerator.device, dtype=vae_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
|
|
||||||
mmdit.requires_grad_(train_mmdit)
|
|
||||||
if not train_mmdit:
|
|
||||||
mmdit.to(accelerator.device, dtype=weight_dtype) # because of unet is not prepared
|
|
||||||
|
|
||||||
training_models = []
|
training_models = []
|
||||||
params_to_optimize = []
|
params_to_optimize = []
|
||||||
# if train_unet:
|
# if train_unet:
|
||||||
|
|||||||
Reference in New Issue
Block a user