sd3 training

This commit is contained in:
Kohya S
2024-06-23 23:38:20 +09:00
parent a518e3c819
commit d53ea22b2a
8 changed files with 1909 additions and 44 deletions

View File

@@ -1,30 +1,226 @@
import math
from typing import Dict
from typing import Dict, Optional, Union
import torch
import safetensors
from safetensors.torch import load_file
from accelerate import init_empty_weights
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import sd3_models
# TODO move some of functions to model_util.py
from library import sdxl_model_util
# region models
def load_models(
ckpt_path: str,
clip_l_path: str,
clip_g_path: str,
t5xxl_path: str,
vae_path: str,
attn_mode: str,
device: Union[str, torch.device],
weight_dtype: torch.dtype,
disable_mmap: bool = False,
t5xxl_device: Optional[str] = None,
t5xxl_dtype: Optional[str] = None,
):
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
if disable_mmap:
return safetensors.torch.load(open(path, "rb").read())
else:
try:
return load_file(path, device=dvc)
except:
return load_file(path) # prevent device invalid Error
t5xxl_device = t5xxl_device or device
logger.info(f"Loading SD3 models from {ckpt_path}...")
state_dict = load_state_dict(ckpt_path)
# load clip_l
clip_l_sd = None
if clip_l_path:
logger.info(f"Loading clip_l from {clip_l_path}...")
clip_l_sd = load_state_dict(clip_l_path)
for key in list(clip_l_sd.keys()):
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
else:
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
clip_l_sd = {}
prefix = "text_encoders.clip_l."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
# load clip_g
clip_g_sd = None
if clip_g_path:
logger.info(f"Loading clip_g from {clip_g_path}...")
clip_g_sd = load_state_dict(clip_g_path)
for key in list(clip_g_sd.keys()):
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
else:
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
clip_g_sd = {}
prefix = "text_encoders.clip_g."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
# load t5xxl
t5xxl_sd = None
if t5xxl_path:
logger.info(f"Loading t5xxl from {t5xxl_path}...")
t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device)
for key in list(t5xxl_sd.keys()):
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
else:
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
t5xxl_sd = {}
prefix = "text_encoders.t5xxl."
for k in list(state_dict.keys()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
# MMDiT and VAE
vae_sd = {}
if vae_path:
logger.info(f"Loading VAE from {vae_path}...")
vae_sd = load_state_dict(vae_path)
else:
# remove prefix "first_stage_model."
vae_sd = {}
vae_prefix = "first_stage_model."
for k in list(state_dict.keys()):
if k.startswith(vae_prefix):
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
mmdit_prefix = "model.diffusion_model."
for k in list(state_dict.keys()):
if k.startswith(mmdit_prefix):
state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)
else:
state_dict.pop(k) # remove other keys
# load MMDiT
logger.info("Building MMDit")
with init_empty_weights():
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
logger.info("Loading state dict...")
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
logger.info(f"Loaded MMDiT: {info}")
# load ClipG and ClipL
if clip_l_sd is None:
clip_l = None
else:
logger.info("Building ClipL")
clip_l = sd3_models.create_clip_l(device, weight_dtype, clip_l_sd)
logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded ClipL: {info}")
clip_l.set_attn_mode(attn_mode)
if clip_g_sd is None:
clip_g = None
else:
logger.info("Building ClipG")
clip_g = sd3_models.create_clip_g(device, weight_dtype, clip_g_sd)
logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded ClipG: {info}")
clip_g.set_attn_mode(attn_mode)
# load T5XXL
if t5xxl_sd is None:
t5xxl = None
else:
logger.info("Building T5XXL")
t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd)
logger.info("Loading state dict...")
info = t5xxl.load_state_dict(t5xxl_sd)
logger.info(f"Loaded T5XXL: {info}")
t5xxl.set_attn_mode(attn_mode)
# load VAE
logger.info("Building VAE")
vae = sd3_models.SDVAE()
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
return mmdit, clip_l, clip_g, t5xxl, vae
# endregion
# region utils
def get_cond(
prompt: str,
tokenizer: sd3_models.SD3Tokenizer,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: sd3_models.T5XXLModel,
t5xxl: Optional[sd3_models.T5XXLModel] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)
def get_cond_from_tokens(
l_tokens,
g_tokens,
t5_tokens,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
l_out, l_pooled = clip_l.encode_token_weights(l_tokens)
g_out, g_pooled = clip_g.encode_token_weights(g_tokens)
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
if device is not None:
lg_out = lg_out.to(device=device)
l_pooled = l_pooled.to(device=device)
g_pooled = g_pooled.to(device=device)
if dtype is not None:
lg_out = lg_out.to(dtype=dtype)
l_pooled = l_pooled.to(dtype=dtype)
g_pooled = g_pooled.to(dtype=dtype)
# t5xxl may be in another device (eg. cpu)
if t5_tokens is None:
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device)
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
else:
t5_out, t5_pooled = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None
t5_out = t5_out.to(lg_out.dtype)
t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None
if device is not None:
t5_out = t5_out.to(device=device)
if dtype is not None:
t5_out = t5_out.to(dtype=dtype)
return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
# return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1)
# used if other sd3 models is available
@@ -111,3 +307,6 @@ class ModelSamplingDiscreteFlow:
# assert max_denoise is False, "max_denoise not implemented"
# max_denoise is always True, I'm not sure why it's there
return sigma * noise + (1.0 - sigma) * latent_image
# endregion