refactor SD3 CLIP to transformers etc.

This commit is contained in:
Kohya S
2024-10-24 19:49:28 +09:00
parent 138dac4aea
commit 623017f716
13 changed files with 1201 additions and 2150 deletions

View File

@@ -29,7 +29,7 @@ init_ipex()
from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library.sd3_train_utils import load_prompts, FlowMatchEulerDiscreteScheduler
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
import library.train_util as train_util
@@ -241,7 +241,7 @@ def train(args):
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:

View File

@@ -231,7 +231,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
tokenize_strategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy: strategy_flux.FluxTextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
prompts = sd3_train_utils.load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in prompts:

View File

@@ -15,7 +15,6 @@ from PIL import Image
from safetensors.torch import save_file
from library import flux_models, flux_utils, strategy_base, train_util
from library.sd3_train_utils import load_prompts
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -70,7 +69,7 @@ def sample_images(
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)

View File

@@ -10,40 +10,21 @@ from safetensors import safe_open
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
from library import flux_models
from library.utils import setup_logging, MemoryEfficientSafeOpen
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
from library import flux_models
from library.utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
# temporary copy from sd3_utils TODO refactor
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
):
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
return load_file(path, device=device)
except:
return load_file(path) # prevent device invalid Error
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
"""
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
@@ -161,8 +142,14 @@ def load_ae(
return ae
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel:
logger.info("Building CLIP")
def load_clip_l(
ckpt_path: Optional[str],
dtype: torch.dtype,
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> CLIPTextModel:
logger.info("Building CLIP-L")
CLIPL_CONFIG = {
"_name_or_path": "clip-vit-large-patch14/",
"architectures": ["CLIPModel"],
@@ -255,15 +242,22 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
with init_empty_weights():
clip = CLIPTextModel._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded CLIP: {info}")
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_t5xxl(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[dict] = None,
) -> T5EncoderModel:
T5_CONFIG_JSON = """
{
@@ -303,6 +297,9 @@ def load_t5xxl(
with init_empty_weights():
t5xxl = T5EncoderModel._from_config(config)
if state_dict is not None:
sd = state_dict
else:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = t5xxl.load_state_dict(sd, strict=False, assign=True)

View File

@@ -57,8 +57,8 @@ ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3-medium"
ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_UNKNOWN = "flux-1"
@@ -140,10 +140,7 @@ def build_metadata(
if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
if sd3 == "m":
arch = ARCH_SD3_M
else:
arch = ARCH_SD3_UNKNOWN
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV

File diff suppressed because it is too large Load Diff

View File

@@ -11,8 +11,8 @@ from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_models, sd3_utils, strategy_base, train_util
from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
@@ -28,60 +28,16 @@ import logging
logger = logging.getLogger(__name__)
from .sdxl_train_util import match_mixed_precision
def load_target_model(
model_type: str,
args: argparse.Namespace,
state_dict: dict,
accelerator: Accelerator,
attn_mode: str,
model_dtype: Optional[torch.dtype],
device: Optional[torch.device],
) -> Union[
sd3_models.MMDiT,
Optional[sd3_models.SDClipModel],
Optional[sd3_models.SDXLClipG],
Optional[sd3_models.T5XXLModel],
sd3_models.SDVAE,
]:
loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu")
for pi in range(accelerator.state.num_processes):
if pi == accelerator.state.local_process_index:
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
if model_type == "mmdit":
model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device)
elif model_type == "clip_l":
model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device)
elif model_type == "clip_g":
model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device)
elif model_type == "t5xxl":
model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device)
elif model_type == "vae":
model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device)
else:
raise ValueError(f"Unknown model type: {model_type}")
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
if args.lowram:
model = model.to(accelerator.device)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
return model
from library import sd3_models, sd3_utils, strategy_base, train_util
def save_models(
ckpt_path: str,
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel],
mmdit: Optional[sd3_models.MMDiT],
vae: Optional[sd3_models.SDVAE],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
sai_metadata: Optional[dict],
save_dtype: Optional[torch.dtype] = None,
):
@@ -101,24 +57,35 @@ def save_models(
update_sd("model.diffusion_model.", mmdit.state_dict())
update_sd("first_stage_model.", vae.state_dict())
if clip_l is not None:
update_sd("text_encoders.clip_l.", clip_l.state_dict())
if clip_g is not None:
update_sd("text_encoders.clip_g.", clip_g.state_dict())
if t5xxl is not None:
update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
# do not support unified checkpoint format for now
# if clip_l is not None:
# update_sd("text_encoders.clip_l.", clip_l.state_dict())
# if clip_g is not None:
# update_sd("text_encoders.clip_g.", clip_g.state_dict())
# if t5xxl is not None:
# update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
save_file(state_dict, ckpt_path, metadata=sai_metadata)
if clip_l is not None:
clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
save_file(clip_l.state_dict(), clip_l_path)
if clip_g is not None:
clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
save_file(clip_g.state_dict(), clip_g_path)
if t5xxl is not None:
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
save_file(t5xxl.state_dict(), t5xxl_path)
def save_sd3_model_on_train_end(
args: argparse.Namespace,
save_dtype: torch.dtype,
epoch: int,
global_step: int,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
@@ -141,9 +108,9 @@ def save_sd3_model_on_epoch_end_or_stepwise(
epoch: int,
num_train_epochs: int,
global_step: int,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel],
clip_l: Optional[CLIPTextModelWithProjection],
clip_g: Optional[CLIPTextModelWithProjection],
t5xxl: Optional[T5EncoderModel],
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
):
@@ -208,23 +175,27 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
)
parser.add_argument(
"--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する"
"--save_clip",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する"
"--save_t5xxl",
action="store_true",
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
)
parser.add_argument(
"--t5xxl_device",
type=str,
default=None,
help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
)
parser.add_argument(
"--t5xxl_dtype",
type=str,
default=None,
help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtypemixed precisionからを使用",
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtypemixed precisionからを使用",
)
# copy from Diffusers
@@ -233,16 +204,25 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
type=str,
default="logit_normal",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
"--logit_mean",
type=float,
default=0.0,
help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均",
)
parser.add_argument(
"--logit_std",
type=float,
default=1.0,
help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd",
)
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
)
@@ -283,7 +263,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
# temporary copied from sd3_minimal_inferece.py
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
start = sampling.timestep(sampling.sigma_max)
end = sampling.timestep(sampling.sigma_min)
timesteps = torch.linspace(start, end, steps)
@@ -327,7 +307,7 @@ def do_sample(
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
sigmas = get_sigmas(model_sampling, steps).to(device)
sigmas = get_all_sigmas(model_sampling, steps).to(device)
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
@@ -371,37 +351,6 @@ def do_sample(
return x
def load_prompts(prompt_file: str) -> List[Dict]:
# read prompts
if prompt_file.endswith(".txt"):
with open(prompt_file, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif prompt_file.endswith(".toml"):
with open(prompt_file, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif prompt_file.endswith(".json"):
with open(prompt_file, "r", encoding="utf-8") as f:
prompts = json.load(f)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
from library.train_util import line_to_prompt_dict
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
return prompts
def sample_images(
accelerator: Accelerator,
args: argparse.Namespace,
@@ -440,7 +389,7 @@ def sample_images(
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
prompts = load_prompts(args.sample_prompts)
prompts = train_util.load_prompts(args.sample_prompts)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
@@ -510,7 +459,7 @@ def sample_image_inference(
accelerator: Accelerator,
args: argparse.Namespace,
mmdit: sd3_models.MMDiT,
text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]],
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
vae: sd3_models.SDVAE,
save_dir,
prompt_dict,
@@ -568,7 +517,7 @@ def sample_image_inference(
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt)
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
lg_out, t5_out, pooled = te_outputs
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# encode negative prompts
@@ -578,7 +527,7 @@ def sample_image_inference(
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt)
neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
lg_out, t5_out, pooled = neg_te_outputs
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# sample image
@@ -609,14 +558,9 @@ def sample_image_inference(
wandb_tracker = accelerator.get_tracker("wandb")
import wandb
# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log(
{f"sample_{i}": wandb.Image(
image,
caption=prompt # positive prompt as a caption
)},
commit=False
)
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
# region Diffusers
@@ -886,4 +830,78 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return self.config.num_train_timesteps
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def compute_density_for_timestep_sampling(
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
):
"""Compute the density for sampling the timesteps when doing SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "logit_normal":
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.nn.functional.sigmoid(u)
elif weighting_scheme == "mode":
u = torch.rand(size=(batch_size,), device="cpu")
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
else:
u = torch.rand(size=(batch_size,), device="cpu")
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
"""
if weighting_scheme == "sigma_sqrt":
weighting = (sigmas**-2.0).float()
elif weighting_scheme == "cosmap":
bot = 1 - 2 * sigmas + 2 * sigmas**2
weighting = 2 / (math.pi * bot)
else:
weighting = torch.ones_like(sigmas)
return weighting
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input, timesteps, sigmas
# endregion

View File

@@ -1,9 +1,12 @@
from dataclasses import dataclass
import math
from typing import Dict, Optional, Union
import re
from typing import Dict, List, Optional, Union
import torch
import safetensors
from safetensors.torch import load_file
from accelerate import init_empty_weights
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
from .utils import setup_logging
@@ -19,18 +22,61 @@ from library import sdxl_model_util
# region models
# TODO remove dependency on flux_utils
from library.utils import load_safetensors
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False):
if disable_mmap:
return safetensors.torch.load(open(path, "rb").read())
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
logger.info(f"Analyzing state dict state...")
# analyze configs
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
x_block_self_attn_layers = []
re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight")
for key in list(state_dict.keys()):
m = re_attn.match(key)
if m:
x_block_self_attn_layers.append(int(m.group(1)))
assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported"
context_embedder_in_features = context_shape[1]
context_embedder_out_features = context_shape[0]
# only supports 3-5-large and 3-medium
if qk_norm is not None:
model_type = "3-5-large"
else:
try:
return load_file(path, device=dvc)
except:
return load_file(path) # prevent device invalid Error
model_type = "3-medium"
params = sd3_models.SD3Params(
patch_size=patch_size,
depth=depth,
num_patches=num_patches,
pos_embed_max_size=pos_embed_max_size,
adm_in_channels=adm_in_channels,
qk_norm=qk_norm,
x_block_self_attn_layers=x_block_self_attn_layers,
context_embedder_in_features=context_embedder_in_features,
context_embedder_out_features=context_embedder_out_features,
model_type=model_type,
)
logger.info(f"Analyzed state dict state: {params}")
return params
def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
def load_mmdit(
state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
) -> sd3_models.MMDiT:
mmdit_sd = {}
mmdit_prefix = "model.diffusion_model."
@@ -40,8 +86,9 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc
# load MMDiT
logger.info("Building MMDit")
params = analyze_state_dict_state(mmdit_sd)
with init_empty_weights():
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
logger.info("Loading state dict...")
info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
@@ -50,20 +97,14 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc
def load_clip_l(
state_dict: Dict,
clip_l_path: Optional[str],
attn_mode: str,
clip_dtype: Optional[Union[str, torch.dtype]],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_l_sd = None
if clip_l_path:
logger.info(f"Loading clip_l from {clip_l_path}...")
clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap)
for key in list(clip_l_sd.keys()):
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
else:
if clip_l_path is None:
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
@@ -72,34 +113,58 @@ def load_clip_l(
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_l_path is None:
logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
return None
# load clip_l
logger.info("Building CLIP-L")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
max_position_embeddings=77,
hidden_act="quick_gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=768,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_l_sd is None:
clip_l = None
else:
logger.info("Building ClipL")
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded ClipL: {info}")
clip_l.set_attn_mode(attn_mode)
return clip_l
logger.info(f"Loading state dict from {clip_l_path}")
clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
if "text_projection.weight" not in clip_l_sd:
logger.info("Adding text_projection.weight to clip_l_sd")
clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-L: {info}")
return clip
def load_clip_g(
state_dict: Dict,
clip_g_path: Optional[str],
attn_mode: str,
clip_dtype: Optional[Union[str, torch.dtype]],
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
clip_g_sd = None
if clip_g_path:
logger.info(f"Loading clip_g from {clip_g_path}...")
clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap)
for key in list(clip_g_sd.keys()):
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
else:
if state_dict is not None:
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
@@ -108,34 +173,53 @@ def load_clip_g(
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
elif clip_g_path is None:
logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
return None
# load clip_g
logger.info("Building CLIP-G")
config = CLIPTextConfig(
vocab_size=49408,
hidden_size=1280,
intermediate_size=5120,
num_hidden_layers=32,
num_attention_heads=20,
max_position_embeddings=77,
hidden_act="gelu",
layer_norm_eps=1e-05,
dropout=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
model_type="clip_text_model",
projection_dim=1280,
# torch_dtype="float32",
# transformers_version="4.25.0.dev0",
)
with init_empty_weights():
clip = CLIPTextModelWithProjection(config)
if clip_g_sd is None:
clip_g = None
else:
logger.info("Building ClipG")
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded ClipG: {info}")
clip_g.set_attn_mode(attn_mode)
return clip_g
logger.info(f"Loading state dict from {clip_g_path}")
clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
logger.info(f"Loaded CLIP-G: {info}")
return clip
def load_t5xxl(
state_dict: Dict,
t5xxl_path: Optional[str],
attn_mode: str,
dtype: Optional[Union[str, torch.dtype]],
device: Union[str, torch.device],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
t5xxl_sd = None
if t5xxl_path:
logger.info(f"Loading t5xxl from {t5xxl_path}...")
t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap)
for key in list(t5xxl_sd.keys()):
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
else:
if state_dict is not None:
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
@@ -144,29 +228,19 @@ def load_t5xxl(
for k in list(state_dict.keys()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
elif t5xxl_path is None:
logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
return None
if t5xxl_sd is None:
t5xxl = None
else:
logger.info("Building T5XXL")
# workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device
t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd)
t5xxl.to(dtype=dtype)
logger.info("Loading state dict...")
info = t5xxl.load_state_dict(t5xxl_sd)
logger.info(f"Loaded T5XXL: {info}")
t5xxl.set_attn_mode(attn_mode)
return t5xxl
return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
def load_vae(
state_dict: Dict,
vae_path: Optional[str],
vae_dtype: Optional[Union[str, torch.dtype]],
device: Optional[Union[str, torch.device]],
disable_mmap: bool = False,
state_dict: Optional[Dict] = None,
):
vae_sd = {}
if vae_path:
@@ -181,299 +255,15 @@ def load_vae(
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
logger.info("Building VAE")
vae = sd3_models.SDVAE()
vae = sd3_models.SDVAE(vae_dtype, device)
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
vae.to(device=device, dtype=vae_dtype)
vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
return vae
def load_models(
ckpt_path: str,
clip_l_path: str,
clip_g_path: str,
t5xxl_path: str,
vae_path: str,
attn_mode: str,
device: Union[str, torch.device],
weight_dtype: Optional[Union[str, torch.dtype]] = None,
disable_mmap: bool = False,
clip_dtype: Optional[Union[str, torch.dtype]] = None,
t5xxl_device: Optional[Union[str, torch.device]] = None,
t5xxl_dtype: Optional[Union[str, torch.dtype]] = None,
vae_dtype: Optional[Union[str, torch.dtype]] = None,
):
"""
Load SD3 models from checkpoint files.
Args:
ckpt_path: Path to the SD3 checkpoint file.
clip_l_path: Path to the clip_l checkpoint file.
clip_g_path: Path to the clip_g checkpoint file.
t5xxl_path: Path to the t5xxl checkpoint file.
vae_path: Path to the VAE checkpoint file.
attn_mode: Attention mode for MMDiT model.
device: Device for MMDiT model.
weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different.
disable_mmap: Disable memory mapping when loading state dict.
clip_dtype: Dtype for Clip models, or None to use default dtype.
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype.
vae_dtype: Dtype for VAE model, or None to use default dtype.
Returns:
Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models.
"""
# In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict.
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
# Therefore, we need clip_dtype and t5xxl_dtype.
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
if disable_mmap:
return safetensors.torch.load(open(path, "rb").read())
else:
try:
return load_file(path, device=dvc)
except:
return load_file(path) # prevent device invalid Error
t5xxl_device = t5xxl_device or device
clip_dtype = clip_dtype or weight_dtype or torch.float32
t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32
vae_dtype = vae_dtype or weight_dtype or torch.float32
logger.info(f"Loading SD3 models from {ckpt_path}...")
state_dict = load_state_dict(ckpt_path)
# load clip_l
clip_l_sd = None
if clip_l_path:
logger.info(f"Loading clip_l from {clip_l_path}...")
clip_l_sd = load_state_dict(clip_l_path)
for key in list(clip_l_sd.keys()):
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
else:
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
clip_l_sd = {}
prefix = "text_encoders.clip_l."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
# load clip_g
clip_g_sd = None
if clip_g_path:
logger.info(f"Loading clip_g from {clip_g_path}...")
clip_g_sd = load_state_dict(clip_g_path)
for key in list(clip_g_sd.keys()):
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
else:
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
clip_g_sd = {}
prefix = "text_encoders.clip_g."
for k in list(state_dict.keys()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
# load t5xxl
t5xxl_sd = None
if t5xxl_path:
logger.info(f"Loading t5xxl from {t5xxl_path}...")
t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device)
for key in list(t5xxl_sd.keys()):
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
else:
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
t5xxl_sd = {}
prefix = "text_encoders.t5xxl."
for k in list(state_dict.keys()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
# MMDiT and VAE
vae_sd = {}
if vae_path:
logger.info(f"Loading VAE from {vae_path}...")
vae_sd = load_state_dict(vae_path)
else:
# remove prefix "first_stage_model."
vae_sd = {}
vae_prefix = "first_stage_model."
for k in list(state_dict.keys()):
if k.startswith(vae_prefix):
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
mmdit_prefix = "model.diffusion_model."
for k in list(state_dict.keys()):
if k.startswith(mmdit_prefix):
state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)
else:
state_dict.pop(k) # remove other keys
# load MMDiT
logger.info("Building MMDit")
with init_empty_weights():
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
logger.info("Loading state dict...")
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
logger.info(f"Loaded MMDiT: {info}")
# load ClipG and ClipL
if clip_l_sd is None:
clip_l = None
else:
logger.info("Building ClipL")
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded ClipL: {info}")
clip_l.set_attn_mode(attn_mode)
if clip_g_sd is None:
clip_g = None
else:
logger.info("Building ClipG")
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded ClipG: {info}")
clip_g.set_attn_mode(attn_mode)
# load T5XXL
if t5xxl_sd is None:
t5xxl = None
else:
logger.info("Building T5XXL")
t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd)
logger.info("Loading state dict...")
info = t5xxl.load_state_dict(t5xxl_sd)
logger.info(f"Loaded T5XXL: {info}")
t5xxl.set_attn_mode(attn_mode)
# load VAE
logger.info("Building VAE")
vae = sd3_models.SDVAE()
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
vae.to(device=device, dtype=vae_dtype)
return mmdit, clip_l, clip_g, t5xxl, vae
# endregion
# region utils
def get_cond(
prompt: str,
tokenizer: sd3_models.SD3Tokenizer,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
print(t5_tokens)
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)
def get_cond_from_tokens(
l_tokens,
g_tokens,
t5_tokens,
clip_l: sd3_models.SDClipModel,
clip_g: sd3_models.SDXLClipG,
t5xxl: Optional[sd3_models.T5XXLModel] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
l_out, l_pooled = clip_l.encode_token_weights(l_tokens)
g_out, g_pooled = clip_g.encode_token_weights(g_tokens)
lg_out = torch.cat([l_out, g_out], dim=-1)
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
if device is not None:
lg_out = lg_out.to(device=device)
l_pooled = l_pooled.to(device=device)
g_pooled = g_pooled.to(device=device)
if dtype is not None:
lg_out = lg_out.to(dtype=dtype)
l_pooled = l_pooled.to(dtype=dtype)
g_pooled = g_pooled.to(dtype=dtype)
# t5xxl may be in another device (eg. cpu)
if t5_tokens is None:
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
else:
t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None
if device is not None:
t5_out = t5_out.to(device=device)
if dtype is not None:
t5_out = t5_out.to(dtype=dtype)
# return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1)
# used if other sd3 models is available
r"""
def get_sd3_configs(state_dict: Dict):
# Important configuration values can be quickly determined by checking shapes in the source file
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
# prefix = "model.diffusion_model."
prefix = ""
patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2]
depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64
num_patches = state_dict[prefix + "pos_embed"].shape[1]
pos_embed_max_size = round(math.sqrt(num_patches))
adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1]
context_shape = state_dict[prefix + "context_embedder.weight"].shape
context_embedder_config = {
"target": "torch.nn.Linear",
"params": {"in_features": context_shape[1], "out_features": context_shape[0]},
}
return {
"patch_size": patch_size,
"depth": depth,
"num_patches": num_patches,
"pos_embed_max_size": pos_embed_max_size,
"adm_in_channels": adm_in_channels,
"context_embedder": context_embedder_config,
}
def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"):
""
Doesn't load state dict.
""
sd3_configs = get_sd3_configs(state_dict)
mmdit = sd3_models.MMDiT(
input_size=None,
pos_embed_max_size=sd3_configs["pos_embed_max_size"],
patch_size=sd3_configs["patch_size"],
in_channels=16,
adm_in_channels=sd3_configs["adm_in_channels"],
depth=sd3_configs["depth"],
mlp_ratio=4,
qk_norm=None,
num_patches=sd3_configs["num_patches"],
context_size=4096,
attn_mode=attn_mode,
)
return mmdit
"""
class ModelSamplingDiscreteFlow:
@@ -509,6 +299,3 @@ class ModelSamplingDiscreteFlow:
# assert max_denoise is False, "max_denoise not implemented"
# max_denoise is always True, I'm not sure why it's there
return sigma * noise + (1.0 - sigma) * latent_image
# endregion

View File

@@ -3,7 +3,7 @@ import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_utils, train_util
from library import sd3_models
@@ -48,45 +48,79 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None:
"""
Args:
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
"""
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
def encode_tokens(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False,
) -> List[torch.Tensor]:
"""
returned embeddings are not masked
"""
clip_l, clip_g, t5xxl = models
clip_l: CLIPTextModel
clip_g: CLIPTextModelWithProjection
t5xxl: T5EncoderModel
if apply_lg_attn_mask is None:
apply_lg_attn_mask = self.apply_lg_attn_mask
if apply_t5_attn_mask is None:
apply_t5_attn_mask = self.apply_t5_attn_mask
l_tokens, g_tokens, t5_tokens = tokens[:3]
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None]
if len(tokens) > 3:
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:]
if not apply_lg_attn_mask:
l_attn_mask = None
g_attn_mask = None
else:
l_attn_mask = l_attn_mask.to(clip_l.device)
g_attn_mask = g_attn_mask.to(clip_g.device)
if not apply_t5_attn_mask:
t5_attn_mask = None
else:
t5_attn_mask = t5_attn_mask.to(t5xxl.device)
else:
l_attn_mask = None
g_attn_mask = None
t5_attn_mask = None
if l_tokens is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
lg_pooled = None
else:
with torch.no_grad():
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
l_out, l_pooled = clip_l(l_tokens)
g_out, g_pooled = clip_g(g_tokens)
if apply_lg_attn_mask:
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
l_pooled = prompt_embeds[0]
l_out = prompt_embeds.hidden_states[-2]
prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
g_pooled = prompt_embeds[0]
g_out = prompt_embeds.hidden_states[-2]
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is not None and t5_tokens is not None:
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
if apply_t5_attn_mask:
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
with torch.no_grad():
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
else:
t5_out = None
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
return [lg_out, t5_out, lg_pooled]
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer
def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
@@ -132,39 +166,38 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
# t5xxl is optional
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray:
l_out = lg_out[..., :768]
g_out = lg_out[..., 768:] # 1280
l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask.
g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask.
return np.concatenate([l_out, g_out], axis=-1)
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
return t5_out * np.expand_dims(t5_attn_mask, -1)
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"] if "t5_out" in data else None
t5_out = data["t5_out"]
if self.apply_lg_attn_mask:
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask)
if self.apply_t5_attn_mask and t5_out is not None:
t5_attn_mask = data["t5_attn_mask"]
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
return [lg_out, t5_out, lg_pooled]
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
@@ -174,7 +207,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens(
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
)
@@ -182,38 +215,41 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out is not None and t5_out.dtype == torch.bfloat16:
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
if t5_out is not None:
t5_out = t5_out.cpu().numpy()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i] if t5_out is not None else None
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6]
clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy()
clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy()
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None
kwargs = {}
if t5_out is not None:
kwargs["t5_out"] = t5_out_i
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
clip_l_attn_mask=clip_l_attn_mask_i,
clip_g_attn_mask=clip_g_attn_mask_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
**kwargs,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
else:
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
@@ -246,41 +282,3 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# test code for Sd3TokenizeStrategy
# tokenizer = sd3_models.SD3Tokenizer()
strategy = Sd3TokenizeStrategy(256)
text = "hello world"
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
# print(l_tokens.shape)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens_2 = strategy.t5xxl(
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(l_tokens_2)
print(g_tokens_2)
print(t5_tokens_2)
# compare
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
text = ",".join(["hello world! this is long text"] * 50)
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
print(f"model max length l: {strategy.clip_l.model_max_length}")
print(f"model max length g: {strategy.clip_g.model_max_length}")
print(f"model max length t5: {strategy.t5xxl.model_max_length}")

View File

@@ -5967,6 +5967,37 @@ def line_to_prompt_dict(line: str) -> dict:
return prompt_dict
def load_prompts(prompt_file: str) -> List[Dict]:
# read prompts
if prompt_file.endswith(".txt"):
with open(prompt_file, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif prompt_file.endswith(".toml"):
with open(prompt_file, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif prompt_file.endswith(".json"):
with open(prompt_file, "r", encoding="utf-8") as f:
prompts = json.load(f)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
from library.train_util import line_to_prompt_dict
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
return prompts
def sample_images_common(
pipe_class,
accelerator: Accelerator,

View File

@@ -13,12 +13,16 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest
import cv2
from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
# region Logging
def add_logging_arguments(parser):
parser.add_argument(
"--console_log_level",
@@ -85,6 +89,11 @@ def setup_logging(args=None, log_level=None, reset=False):
logger.info(msg_init)
# endregion
# region PyTorch utils
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
"""
Convert a string to a torch.dtype
@@ -304,6 +313,35 @@ class MemoryEfficientSafeOpen:
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
def load_safetensors(
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
) -> dict[str, torch.Tensor]:
if disable_mmap:
# return safetensors.torch.load(open(path, "rb").read())
# use experimental loader
# logger.info(f"Loading without mmap (experimental)")
state_dict = {}
with MemoryEfficientSafeOpen(path) as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
return state_dict
else:
try:
state_dict = load_file(path, device=device)
except:
state_dict = load_file(path) # prevent device invalid Error
if dtype is not None:
for key in state_dict.keys():
state_dict[key] = state_dict[key].to(dtype=dtype)
return state_dict
# endregion
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
@@ -323,9 +361,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
# endregion
# TODO make inf_utils.py
# region Gradual Latent hires fix

View File

@@ -12,6 +12,7 @@ import torch
from safetensors.torch import safe_open, load_file
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModelWithProjection, T5EncoderModel
from library.device_utils import init_ipex, get_preferred_device
@@ -25,11 +26,14 @@ import logging
logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_sd3
from library.utils import load_safetensors
def get_noise(seed, latent):
generator = torch.manual_seed(seed)
return torch.randn(latent.size(), dtype=torch.float32, layout=latent.layout, generator=generator, device="cpu").to(latent.dtype)
def get_noise(seed, latent, device="cpu"):
# generator = torch.manual_seed(seed)
generator = torch.Generator(device)
generator.manual_seed(seed)
return torch.randn(latent.size(), dtype=latent.dtype, layout=latent.layout, generator=generator, device=device)
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
@@ -59,7 +63,7 @@ def do_sample(
neg_cond: Tuple[torch.Tensor, torch.Tensor],
mmdit: sd3_models.MMDiT,
steps: int,
guidance_scale: float,
cfg_scale: float,
dtype: torch.dtype,
device: str,
):
@@ -71,7 +75,7 @@ def do_sample(
latent = latent.to(dtype).to(device)
noise = get_noise(seed, latent).to(device)
noise = get_noise(seed, latent, device)
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
@@ -105,7 +109,7 @@ def do_sample(
batched = model_sampling.calculate_denoised(sigma_hat, model_output, x)
pos_out, neg_out = batched.chunk(2)
denoised = neg_out + (pos_out - neg_out) * guidance_scale
denoised = neg_out + (pos_out - neg_out) * cfg_scale
# print(denoised.shape)
# d = to_d(x, sigma_hat, denoised)
@@ -122,230 +126,68 @@ def do_sample(
x = x.to(dtype)
latent = x
scale_factor = 1.5305
shift_factor = 0.0609
# def process_out(self, latent):
# return (latent / self.scale_factor) + self.shift_factor
latent = (latent / scale_factor) + shift_factor
latent = vae.process_out(latent)
return latent
if __name__ == "__main__":
target_height = 1024
target_width = 1024
# steps = 50 # 28 # 50
guidance_scale = 5
# seed = 1 # None # 1
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--clip_g", type=str, required=False)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--t5xxl_token_length", type=int, default=77, help="t5xxl token length, default: 77")
parser.add_argument("--apply_lg_attn_mask", action="store_true")
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--output_dir", type=str, default=".")
parser.add_argument("--do_not_use_t5xxl", action="store_true")
parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50)
# parser.add_argument(
# "--lora_weights",
# type=str,
# nargs="*",
# default=[],
# help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
# )
# parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
seed = args.seed
steps = args.steps
sd3_dtype = torch.float32
if args.fp16:
sd3_dtype = torch.float16
elif args.bf16:
sd3_dtype = torch.bfloat16
# TODO test with separated safetenors files for each model
# load state dict
logger.info(f"Loading SD3 models from {args.ckpt_path}...")
state_dict = load_file(args.ckpt_path)
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_g: remove prefix "text_encoders.clip_g."
logger.info("clip_g is included in the checkpoint")
clip_g_sd = {}
prefix = "text_encoders.clip_g."
for k, v in list(state_dict.items()):
if k.startswith(prefix):
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
else:
logger.info(f"Lodaing clip_g from {args.clip_g}...")
clip_g_sd = load_file(args.clip_g)
for key in list(clip_g_sd.keys()):
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
# found clip_l: remove prefix "text_encoders.clip_l."
logger.info("clip_l is included in the checkpoint")
clip_l_sd = {}
prefix = "text_encoders.clip_l."
for k, v in list(state_dict.items()):
if k.startswith(prefix):
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
else:
logger.info(f"Lodaing clip_l from {args.clip_l}...")
clip_l_sd = load_file(args.clip_l)
for key in list(clip_l_sd.keys()):
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
# found t5xxl: remove prefix "text_encoders.t5xxl."
logger.info("t5xxl is included in the checkpoint")
if not args.do_not_use_t5xxl:
t5xxl_sd = {}
prefix = "text_encoders.t5xxl."
for k, v in list(state_dict.items()):
if k.startswith(prefix):
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
else:
logger.info("but not used")
for key in list(state_dict.keys()):
if key.startswith("text_encoders.t5xxl."):
state_dict.pop(key)
t5xxl_sd = None
elif args.t5xxl:
assert not args.do_not_use_t5xxl, "t5xxl is not used but specified"
logger.info(f"Lodaing t5xxl from {args.t5xxl}...")
t5xxl_sd = load_file(args.t5xxl)
for key in list(t5xxl_sd.keys()):
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
else:
logger.info("t5xxl is not used")
t5xxl_sd = None
use_t5xxl = t5xxl_sd is not None
# MMDiT and VAE
vae_sd = {}
vae_prefix = "first_stage_model."
mmdit_prefix = "model.diffusion_model."
for k, v in list(state_dict.items()):
if k.startswith(vae_prefix):
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
elif k.startswith(mmdit_prefix):
state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)
# load tokenizers
logger.info("Loading tokenizers...")
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
# load models
# logger.info("Create MMDiT from SD3 checkpoint...")
# mmdit = sd3_utils.create_mmdit_from_sd3_checkpoint(state_dict)
logger.info("Create MMDiT")
mmdit = sd3_models.create_mmdit_sd3_medium_configs(args.attn_mode)
logger.info("Loading state dict...")
info = mmdit.load_state_dict(state_dict)
logger.info(f"Loaded MMDiT: {info}")
logger.info(f"Move MMDiT to {device} and {sd3_dtype}...")
mmdit.to(device, dtype=sd3_dtype)
mmdit.eval()
# load VAE
logger.info("Create VAE")
vae = sd3_models.SDVAE()
logger.info("Loading state dict...")
info = vae.load_state_dict(vae_sd)
logger.info(f"Loaded VAE: {info}")
logger.info(f"Move VAE to {device} and {sd3_dtype}...")
vae.to(device, dtype=sd3_dtype)
vae.eval()
# load text encoders
logger.info("Create clip_l")
clip_l = sd3_models.create_clip_l(device, sd3_dtype, clip_l_sd)
logger.info("Loading state dict...")
info = clip_l.load_state_dict(clip_l_sd)
logger.info(f"Loaded clip_l: {info}")
logger.info(f"Move clip_l to {device} and {sd3_dtype}...")
clip_l.to(device, dtype=sd3_dtype)
clip_l.eval()
logger.info(f"Set attn_mode to {args.attn_mode}...")
clip_l.set_attn_mode(args.attn_mode)
logger.info("Create clip_g")
clip_g = sd3_models.create_clip_g(device, sd3_dtype, clip_g_sd)
logger.info("Loading state dict...")
info = clip_g.load_state_dict(clip_g_sd)
logger.info(f"Loaded clip_g: {info}")
logger.info(f"Move clip_g to {device} and {sd3_dtype}...")
clip_g.to(device, dtype=sd3_dtype)
clip_g.eval()
logger.info(f"Set attn_mode to {args.attn_mode}...")
clip_g.set_attn_mode(args.attn_mode)
if use_t5xxl:
logger.info("Create t5xxl")
t5xxl = sd3_models.create_t5xxl(device, sd3_dtype, t5xxl_sd)
logger.info("Loading state dict...")
info = t5xxl.load_state_dict(t5xxl_sd)
logger.info(f"Loaded t5xxl: {info}")
logger.info(f"Move t5xxl to {device} and {sd3_dtype}...")
t5xxl.to(device, dtype=sd3_dtype)
# t5xxl.to("cpu", dtype=torch.float32) # run on CPU
t5xxl.eval()
logger.info(f"Set attn_mode to {args.attn_mode}...")
t5xxl.set_attn_mode(args.attn_mode)
else:
t5xxl = None
def generate_image(
mmdit: sd3_models.MMDiT,
vae: sd3_models.SDVAE,
clip_l: CLIPTextModelWithProjection,
clip_g: CLIPTextModelWithProjection,
t5xxl: T5EncoderModel,
steps: int,
prompt: str,
seed: int,
target_width: int,
target_height: int,
device: str,
negative_prompt: str,
cfg_scale: float,
):
# prepare embeddings
logger.info("Encoding prompts...")
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
tokens_and_masks = tokenize_strategy.tokenize(args.prompt)
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
# TODO support one-by-one offloading
clip_l.to(device)
clip_g.to(device)
t5xxl.to(device)
with torch.no_grad():
tokens_and_masks = tokenize_strategy.tokenize(prompt)
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
tokens_and_masks = tokenize_strategy.tokenize(args.negative_prompt)
lg_out, t5_out, pooled = encoding_strategy.encode_tokens(
tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
lg_out, t5_out, pooled, neg_l_attn_mask, neg_g_attn_mask, neg_t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, clip_g, t5xxl], tokens_and_masks, args.apply_lg_attn_mask, args.apply_t5_attn_mask
)
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
# attn masks are not used currently
if args.offload:
clip_l.to("cpu")
clip_g.to("cpu")
t5xxl.to("cpu")
# generate image
logger.info("Generating image...")
latent_sampled = do_sample(
target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, guidance_scale, sd3_dtype, device
)
mmdit.to(device)
latent_sampled = do_sample(target_height, target_width, None, seed, cond, neg_cond, mmdit, steps, cfg_scale, sd3_dtype, device)
if args.offload:
mmdit.to("cpu")
# latent to image
vae.to(device)
with torch.no_grad():
image = vae.decode(latent_sampled)
if args.offload:
vae.to("cpu")
image = image.float()
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0]
decoded_np = 255.0 * np.moveaxis(image.cpu().numpy(), 0, 2)
@@ -359,3 +201,179 @@ if __name__ == "__main__":
out_image.save(output_path)
logger.info(f"Saved image to {output_path}")
if __name__ == "__main__":
target_height = 1024
target_width = 1024
# steps = 50 # 28 # 50
# cfg_scale = 5
# seed = 1 # None # 1
device = get_preferred_device()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--clip_g", type=str, required=False)
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--t5xxl_token_length", type=int, default=256, help="t5xxl token length, default: 256")
parser.add_argument("--apply_lg_attn_mask", action="store_true")
parser.add_argument("--apply_t5_attn_mask", action="store_true")
parser.add_argument("--prompt", type=str, default="A photo of a cat")
# parser.add_argument("--prompt2", type=str, default=None) # do not support different prompts for text encoders
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--cfg_scale", type=float, default=5.0)
parser.add_argument("--offload", action="store_true", help="Offload to CPU")
parser.add_argument("--output_dir", type=str, default=".")
# parser.add_argument("--do_not_use_t5xxl", action="store_true")
# parser.add_argument("--attn_mode", type=str, default="torch", help="torch (SDPA) or xformers. default: torch")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--steps", type=int, default=50)
# parser.add_argument(
# "--lora_weights",
# type=str,
# nargs="*",
# default=[],
# help="LoRA weights, only supports networks.lora, each argument is a `path;multiplier` (semi-colon separated)",
# )
parser.add_argument("--width", type=int, default=target_width)
parser.add_argument("--height", type=int, default=target_height)
parser.add_argument("--interactive", action="store_true")
args = parser.parse_args()
seed = args.seed
steps = args.steps
sd3_dtype = torch.float32
if args.fp16:
sd3_dtype = torch.float16
elif args.bf16:
sd3_dtype = torch.bfloat16
loading_device = "cpu" if args.offload else device
# load state dict
logger.info(f"Loading SD3 models from {args.ckpt_path}...")
# state_dict = load_file(args.ckpt_path)
state_dict = load_safetensors(args.ckpt_path, loading_device, disable_mmap=True, dtype=sd3_dtype)
# load text encoders
clip_l = sd3_utils.load_clip_l(args.clip_l, sd3_dtype, loading_device, state_dict=state_dict)
clip_g = sd3_utils.load_clip_g(args.clip_g, sd3_dtype, loading_device, state_dict=state_dict)
t5xxl = sd3_utils.load_t5xxl(args.t5xxl, sd3_dtype, loading_device, state_dict=state_dict)
# MMDiT and VAE
vae = sd3_utils.load_vae(None, sd3_dtype, loading_device, state_dict=state_dict)
mmdit = sd3_utils.load_mmdit(state_dict, sd3_dtype, loading_device)
clip_l.to(sd3_dtype)
clip_g.to(sd3_dtype)
t5xxl.to(sd3_dtype)
vae.to(sd3_dtype)
mmdit.to(sd3_dtype)
if not args.offload:
# make sure to move to the device: some tensors are created in the constructor on the CPU
clip_l.to(device)
clip_g.to(device)
t5xxl.to(device)
vae.to(device)
mmdit.to(device)
clip_l.eval()
clip_g.eval()
t5xxl.eval()
mmdit.eval()
vae.eval()
# load tokenizers
logger.info("Loading tokenizers...")
tokenize_strategy = strategy_sd3.Sd3TokenizeStrategy(args.t5xxl_token_length)
encoding_strategy = strategy_sd3.Sd3TextEncodingStrategy()
if not args.interactive:
generate_image(
mmdit,
vae,
clip_l,
clip_g,
t5xxl,
args.steps,
args.prompt,
args.seed,
args.width,
args.height,
device,
args.negative_prompt,
args.cfg_scale,
)
else:
# loop for interactive
width = args.width
height = args.height
steps = None
cfg_scale = args.cfg_scale
while True:
print(
"Enter prompt (empty to exit). Options: --w <width> --h <height> --s <steps> --d <seed>"
" --n <negative prompt>, `--n -` for empty negative prompt"
"Options are kept for the next prompt. Current options:"
f" width={width}, height={height}, steps={steps}, seed={seed}, cfg_scale={cfg_scale}"
)
prompt = input()
if prompt == "":
break
# parse options
options = prompt.split("--")
prompt = options[0].strip()
seed = None
negative_prompt = None
for opt in options[1:]:
try:
opt = opt.strip()
if opt.startswith("w"):
width = int(opt[1:].strip())
elif opt.startswith("h"):
height = int(opt[1:].strip())
elif opt.startswith("s"):
steps = int(opt[1:].strip())
elif opt.startswith("d"):
seed = int(opt[1:].strip())
# elif opt.startswith("m"):
# mutipliers = opt[1:].strip().split(",")
# if len(mutipliers) != len(lora_models):
# logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
# continue
# for i, lora_model in enumerate(lora_models):
# lora_model.set_multiplier(float(mutipliers[i]))
elif opt.startswith("n"):
negative_prompt = opt[1:].strip()
if negative_prompt == "-":
negative_prompt = ""
elif opt.startswith("c"):
cfg_scale = float(opt[1:].strip())
except ValueError as e:
logger.error(f"Invalid option: {opt}, {e}")
generate_image(
mmdit,
vae,
clip_l,
clip_g,
t5xxl,
steps if steps is not None else args.steps,
prompt,
seed if seed is not None else args.seed,
width,
height,
device,
negative_prompt if negative_prompt is not None else args.negative_prompt,
cfg_scale,
)
logger.info("Done!")

File diff suppressed because it is too large Load Diff