mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
refactor SD3 CLIP to transformers etc.
This commit is contained in:
@@ -15,7 +15,6 @@ from PIL import Image
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from library import flux_models, flux_utils, strategy_base, train_util
|
||||
from library.sd3_train_utils import load_prompts
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -70,7 +69,7 @@ def sample_images(
|
||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||
|
||||
prompts = load_prompts(args.sample_prompts)
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -10,40 +10,21 @@ from safetensors import safe_open
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import CLIPTextModel, CLIPConfig, T5EncoderModel, T5Config
|
||||
|
||||
from library import flux_models
|
||||
|
||||
from library.utils import setup_logging, MemoryEfficientSafeOpen
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import flux_models
|
||||
from library.utils import load_safetensors
|
||||
|
||||
MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
MODEL_NAME_DEV = "dev"
|
||||
MODEL_NAME_SCHNELL = "schnell"
|
||||
|
||||
|
||||
# temporary copy from sd3_utils TODO refactor
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||
):
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
|
||||
return state_dict
|
||||
else:
|
||||
try:
|
||||
return load_file(path, device=device)
|
||||
except:
|
||||
return load_file(path) # prevent device invalid Error
|
||||
|
||||
|
||||
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||
"""
|
||||
チェックポイントの状態を分析し、DiffusersかBFLか、devかschnellか、ブロック数を計算して返す。
|
||||
@@ -161,8 +142,14 @@ def load_ae(
|
||||
return ae
|
||||
|
||||
|
||||
def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False) -> CLIPTextModel:
|
||||
logger.info("Building CLIP")
|
||||
def load_clip_l(
|
||||
ckpt_path: Optional[str],
|
||||
dtype: torch.dtype,
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[dict] = None,
|
||||
) -> CLIPTextModel:
|
||||
logger.info("Building CLIP-L")
|
||||
CLIPL_CONFIG = {
|
||||
"_name_or_path": "clip-vit-large-patch14/",
|
||||
"architectures": ["CLIPModel"],
|
||||
@@ -255,15 +242,22 @@ def load_clip_l(ckpt_path: str, dtype: torch.dtype, device: Union[str, torch.dev
|
||||
with init_empty_weights():
|
||||
clip = CLIPTextModel._from_config(config)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
if state_dict is not None:
|
||||
sd = state_dict
|
||||
else:
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = clip.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded CLIP: {info}")
|
||||
logger.info(f"Loaded CLIP-L: {info}")
|
||||
return clip
|
||||
|
||||
|
||||
def load_t5xxl(
|
||||
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
|
||||
ckpt_path: str,
|
||||
dtype: Optional[torch.dtype],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[dict] = None,
|
||||
) -> T5EncoderModel:
|
||||
T5_CONFIG_JSON = """
|
||||
{
|
||||
@@ -303,8 +297,11 @@ def load_t5xxl(
|
||||
with init_empty_weights():
|
||||
t5xxl = T5EncoderModel._from_config(config)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
if state_dict is not None:
|
||||
sd = state_dict
|
||||
else:
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = t5xxl.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded T5xxl: {info}")
|
||||
return t5xxl
|
||||
|
||||
@@ -57,8 +57,8 @@ ARCH_SD_V1 = "stable-diffusion-v1"
|
||||
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
|
||||
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
|
||||
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
ARCH_SD3_M = "stable-diffusion-3-medium"
|
||||
ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
||||
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
|
||||
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
|
||||
ARCH_FLUX_1_DEV = "flux-1-dev"
|
||||
ARCH_FLUX_1_UNKNOWN = "flux-1"
|
||||
|
||||
@@ -140,10 +140,7 @@ def build_metadata(
|
||||
if sdxl:
|
||||
arch = ARCH_SD_XL_V1_BASE
|
||||
elif sd3 is not None:
|
||||
if sd3 == "m":
|
||||
arch = ARCH_SD3_M
|
||||
else:
|
||||
arch = ARCH_SD3_UNKNOWN
|
||||
arch = ARCH_SD3_M + "-" + sd3
|
||||
elif flux is not None:
|
||||
if flux == "dev":
|
||||
arch = ARCH_FLUX_1_DEV
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,8 +11,8 @@ from safetensors.torch import save_file
|
||||
from accelerate import Accelerator, PartialState
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModelWithProjection, T5EncoderModel
|
||||
|
||||
from library import sd3_models, sd3_utils, strategy_base, train_util
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -28,60 +28,16 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .sdxl_train_util import match_mixed_precision
|
||||
|
||||
|
||||
def load_target_model(
|
||||
model_type: str,
|
||||
args: argparse.Namespace,
|
||||
state_dict: dict,
|
||||
accelerator: Accelerator,
|
||||
attn_mode: str,
|
||||
model_dtype: Optional[torch.dtype],
|
||||
device: Optional[torch.device],
|
||||
) -> Union[
|
||||
sd3_models.MMDiT,
|
||||
Optional[sd3_models.SDClipModel],
|
||||
Optional[sd3_models.SDXLClipG],
|
||||
Optional[sd3_models.T5XXLModel],
|
||||
sd3_models.SDVAE,
|
||||
]:
|
||||
loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu")
|
||||
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
if model_type == "mmdit":
|
||||
model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device)
|
||||
elif model_type == "clip_l":
|
||||
model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device)
|
||||
elif model_type == "clip_g":
|
||||
model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device)
|
||||
elif model_type == "t5xxl":
|
||||
model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device)
|
||||
elif model_type == "vae":
|
||||
model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device)
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
|
||||
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
|
||||
if args.lowram:
|
||||
model = model.to(accelerator.device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return model
|
||||
from library import sd3_models, sd3_utils, strategy_base, train_util
|
||||
|
||||
|
||||
def save_models(
|
||||
ckpt_path: str,
|
||||
mmdit: sd3_models.MMDiT,
|
||||
vae: sd3_models.SDVAE,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel],
|
||||
mmdit: Optional[sd3_models.MMDiT],
|
||||
vae: Optional[sd3_models.SDVAE],
|
||||
clip_l: Optional[CLIPTextModelWithProjection],
|
||||
clip_g: Optional[CLIPTextModelWithProjection],
|
||||
t5xxl: Optional[T5EncoderModel],
|
||||
sai_metadata: Optional[dict],
|
||||
save_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
@@ -101,24 +57,35 @@ def save_models(
|
||||
update_sd("model.diffusion_model.", mmdit.state_dict())
|
||||
update_sd("first_stage_model.", vae.state_dict())
|
||||
|
||||
if clip_l is not None:
|
||||
update_sd("text_encoders.clip_l.", clip_l.state_dict())
|
||||
if clip_g is not None:
|
||||
update_sd("text_encoders.clip_g.", clip_g.state_dict())
|
||||
if t5xxl is not None:
|
||||
update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
|
||||
# do not support unified checkpoint format for now
|
||||
# if clip_l is not None:
|
||||
# update_sd("text_encoders.clip_l.", clip_l.state_dict())
|
||||
# if clip_g is not None:
|
||||
# update_sd("text_encoders.clip_g.", clip_g.state_dict())
|
||||
# if t5xxl is not None:
|
||||
# update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
|
||||
|
||||
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
||||
|
||||
if clip_l is not None:
|
||||
clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
|
||||
save_file(clip_l.state_dict(), clip_l_path)
|
||||
if clip_g is not None:
|
||||
clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
|
||||
save_file(clip_g.state_dict(), clip_g_path)
|
||||
if t5xxl is not None:
|
||||
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
|
||||
save_file(t5xxl.state_dict(), t5xxl_path)
|
||||
|
||||
|
||||
def save_sd3_model_on_train_end(
|
||||
args: argparse.Namespace,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel],
|
||||
clip_l: Optional[CLIPTextModelWithProjection],
|
||||
clip_g: Optional[CLIPTextModelWithProjection],
|
||||
t5xxl: Optional[T5EncoderModel],
|
||||
mmdit: sd3_models.MMDiT,
|
||||
vae: sd3_models.SDVAE,
|
||||
):
|
||||
@@ -141,9 +108,9 @@ def save_sd3_model_on_epoch_end_or_stepwise(
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel],
|
||||
clip_l: Optional[CLIPTextModelWithProjection],
|
||||
clip_g: Optional[CLIPTextModelWithProjection],
|
||||
t5xxl: Optional[T5EncoderModel],
|
||||
mmdit: sd3_models.MMDiT,
|
||||
vae: sd3_models.SDVAE,
|
||||
):
|
||||
@@ -208,23 +175,27 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
||||
help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する"
|
||||
"--save_clip",
|
||||
action="store_true",
|
||||
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する"
|
||||
"--save_t5xxl",
|
||||
action="store_true",
|
||||
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--t5xxl_device",
|
||||
type=str,
|
||||
default=None,
|
||||
help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
|
||||
help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_dtype",
|
||||
type=str,
|
||||
default=None,
|
||||
help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用",
|
||||
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用",
|
||||
)
|
||||
|
||||
# copy from Diffusers
|
||||
@@ -233,16 +204,25 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
||||
type=str,
|
||||
default="logit_normal",
|
||||
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
|
||||
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
||||
"--logit_mean",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logit_std",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd",
|
||||
)
|
||||
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
|
||||
parser.add_argument(
|
||||
"--mode_scale",
|
||||
type=float,
|
||||
default=1.29,
|
||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
|
||||
)
|
||||
|
||||
|
||||
@@ -283,7 +263,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
|
||||
# temporary copied from sd3_minimal_inferece.py
|
||||
|
||||
|
||||
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
|
||||
def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
|
||||
start = sampling.timestep(sampling.sigma_max)
|
||||
end = sampling.timestep(sampling.sigma_min)
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
@@ -327,7 +307,7 @@ def do_sample(
|
||||
|
||||
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
|
||||
|
||||
sigmas = get_sigmas(model_sampling, steps).to(device)
|
||||
sigmas = get_all_sigmas(model_sampling, steps).to(device)
|
||||
|
||||
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
|
||||
|
||||
@@ -371,37 +351,6 @@ def do_sample(
|
||||
return x
|
||||
|
||||
|
||||
def load_prompts(prompt_file: str) -> List[Dict]:
|
||||
# read prompts
|
||||
if prompt_file.endswith(".txt"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif prompt_file.endswith(".toml"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif prompt_file.endswith(".json"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
# preprocess prompts
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
from library.train_util import line_to_prompt_dict
|
||||
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
|
||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
def sample_images(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
@@ -440,7 +389,7 @@ def sample_images(
|
||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||
|
||||
prompts = load_prompts(args.sample_prompts)
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@@ -510,7 +459,7 @@ def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
mmdit: sd3_models.MMDiT,
|
||||
text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]],
|
||||
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
|
||||
vae: sd3_models.SDVAE,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
@@ -568,7 +517,7 @@ def sample_image_inference(
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt)
|
||||
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
|
||||
|
||||
lg_out, t5_out, pooled = te_outputs
|
||||
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs
|
||||
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
# encode negative prompts
|
||||
@@ -578,7 +527,7 @@ def sample_image_inference(
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt)
|
||||
neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
|
||||
|
||||
lg_out, t5_out, pooled = neg_te_outputs
|
||||
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs
|
||||
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
# sample image
|
||||
@@ -609,14 +558,9 @@ def sample_image_inference(
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
|
||||
import wandb
|
||||
|
||||
# not to commit images to avoid inconsistency between training and logging steps
|
||||
wandb_tracker.log(
|
||||
{f"sample_{i}": wandb.Image(
|
||||
image,
|
||||
caption=prompt # positive prompt as a caption
|
||||
)},
|
||||
commit=False
|
||||
)
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
# region Diffusers
|
||||
@@ -886,4 +830,78 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
|
||||
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
||||
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
||||
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
||||
timesteps = timesteps.to(device)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
|
||||
def compute_density_for_timestep_sampling(
|
||||
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
||||
):
|
||||
"""Compute the density for sampling the timesteps when doing SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
||||
u = torch.nn.functional.sigmoid(u)
|
||||
elif weighting_scheme == "mode":
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
else:
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
u = compute_density_for_timestep_sampling(
|
||||
weighting_scheme=args.weighting_scheme,
|
||||
batch_size=bsz,
|
||||
logit_mean=args.logit_mean,
|
||||
logit_std=args.logit_std,
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
|
||||
return noisy_model_input, timesteps, sigmas
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from typing import Dict, Optional, Union
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
import torch
|
||||
import safetensors
|
||||
from safetensors.torch import load_file
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPConfig, CLIPTextConfig
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
@@ -19,18 +22,61 @@ from library import sdxl_model_util
|
||||
|
||||
# region models
|
||||
|
||||
# TODO remove dependency on flux_utils
|
||||
from library.utils import load_safetensors
|
||||
from library.flux_utils import load_t5xxl as flux_utils_load_t5xxl
|
||||
|
||||
def load_safetensors(path: str, dvc: Union[str, torch.device], disable_mmap: bool = False):
|
||||
if disable_mmap:
|
||||
return safetensors.torch.load(open(path, "rb").read())
|
||||
|
||||
def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
|
||||
logger.info(f"Analyzing state dict state...")
|
||||
|
||||
# analyze configs
|
||||
patch_size = state_dict[f"{prefix}x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[f"{prefix}x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[f"{prefix}pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = state_dict[f"{prefix}y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[f"{prefix}context_embedder.weight"].shape
|
||||
qk_norm = "rms" if f"{prefix}joint_blocks.0.context_block.attn.ln_k.weight" in state_dict.keys() else None
|
||||
|
||||
# x_block_self_attn_layers.append(int(key.split(".x_block.attn2.ln_k.weight")[0].split(".")[-1]))
|
||||
x_block_self_attn_layers = []
|
||||
re_attn = re.compile(r".(\d+).x_block.attn2.ln_k.weight")
|
||||
for key in list(state_dict.keys()):
|
||||
m = re_attn.match(key)
|
||||
if m:
|
||||
x_block_self_attn_layers.append(int(m.group(1)))
|
||||
|
||||
assert len(x_block_self_attn_layers) == 0, "x_block_self_attn_layers is not supported"
|
||||
|
||||
context_embedder_in_features = context_shape[1]
|
||||
context_embedder_out_features = context_shape[0]
|
||||
|
||||
# only supports 3-5-large and 3-medium
|
||||
if qk_norm is not None:
|
||||
model_type = "3-5-large"
|
||||
else:
|
||||
try:
|
||||
return load_file(path, device=dvc)
|
||||
except:
|
||||
return load_file(path) # prevent device invalid Error
|
||||
model_type = "3-medium"
|
||||
|
||||
params = sd3_models.SD3Params(
|
||||
patch_size=patch_size,
|
||||
depth=depth,
|
||||
num_patches=num_patches,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
adm_in_channels=adm_in_channels,
|
||||
qk_norm=qk_norm,
|
||||
x_block_self_attn_layers=x_block_self_attn_layers,
|
||||
context_embedder_in_features=context_embedder_in_features,
|
||||
context_embedder_out_features=context_embedder_out_features,
|
||||
model_type=model_type,
|
||||
)
|
||||
logger.info(f"Analyzed state dict state: {params}")
|
||||
return params
|
||||
|
||||
|
||||
def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device]):
|
||||
def load_mmdit(
|
||||
state_dict: Dict, dtype: Optional[Union[str, torch.dtype]], device: Union[str, torch.device], attn_mode: str = "torch"
|
||||
) -> sd3_models.MMDiT:
|
||||
mmdit_sd = {}
|
||||
|
||||
mmdit_prefix = "model.diffusion_model."
|
||||
@@ -40,8 +86,9 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc
|
||||
|
||||
# load MMDiT
|
||||
logger.info("Building MMDit")
|
||||
params = analyze_state_dict_state(mmdit_sd)
|
||||
with init_empty_weights():
|
||||
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
|
||||
mmdit = sd3_models.create_sd3_mmdit(params, attn_mode)
|
||||
|
||||
logger.info("Loading state dict...")
|
||||
info = sdxl_model_util._load_state_dict_on_device(mmdit, mmdit_sd, device, dtype)
|
||||
@@ -50,20 +97,14 @@ def load_mmdit(state_dict: Dict, attn_mode: str, dtype: Optional[Union[str, torc
|
||||
|
||||
|
||||
def load_clip_l(
|
||||
state_dict: Dict,
|
||||
clip_l_path: Optional[str],
|
||||
attn_mode: str,
|
||||
clip_dtype: Optional[Union[str, torch.dtype]],
|
||||
dtype: Optional[Union[str, torch.dtype]],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[Dict] = None,
|
||||
):
|
||||
clip_l_sd = None
|
||||
if clip_l_path:
|
||||
logger.info(f"Loading clip_l from {clip_l_path}...")
|
||||
clip_l_sd = load_safetensors(clip_l_path, device, disable_mmap)
|
||||
for key in list(clip_l_sd.keys()):
|
||||
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
|
||||
else:
|
||||
if clip_l_path is None:
|
||||
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||
# found clip_l: remove prefix "text_encoders.clip_l."
|
||||
logger.info("clip_l is included in the checkpoint")
|
||||
@@ -72,34 +113,58 @@ def load_clip_l(
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
elif clip_l_path is None:
|
||||
logger.info("clip_l is not included in the checkpoint and clip_l_path is not provided")
|
||||
return None
|
||||
|
||||
# load clip_l
|
||||
logger.info("Building CLIP-L")
|
||||
config = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=768,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=12,
|
||||
num_attention_heads=12,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="quick_gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=768,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
clip = CLIPTextModelWithProjection(config)
|
||||
|
||||
if clip_l_sd is None:
|
||||
clip_l = None
|
||||
else:
|
||||
logger.info("Building ClipL")
|
||||
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = clip_l.load_state_dict(clip_l_sd)
|
||||
logger.info(f"Loaded ClipL: {info}")
|
||||
clip_l.set_attn_mode(attn_mode)
|
||||
return clip_l
|
||||
logger.info(f"Loading state dict from {clip_l_path}")
|
||||
clip_l_sd = load_safetensors(clip_l_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
if "text_projection.weight" not in clip_l_sd:
|
||||
logger.info("Adding text_projection.weight to clip_l_sd")
|
||||
clip_l_sd["text_projection.weight"] = torch.eye(768, dtype=dtype, device=device)
|
||||
|
||||
info = clip.load_state_dict(clip_l_sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded CLIP-L: {info}")
|
||||
return clip
|
||||
|
||||
|
||||
def load_clip_g(
|
||||
state_dict: Dict,
|
||||
clip_g_path: Optional[str],
|
||||
attn_mode: str,
|
||||
clip_dtype: Optional[Union[str, torch.dtype]],
|
||||
dtype: Optional[Union[str, torch.dtype]],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[Dict] = None,
|
||||
):
|
||||
clip_g_sd = None
|
||||
if clip_g_path:
|
||||
logger.info(f"Loading clip_g from {clip_g_path}...")
|
||||
clip_g_sd = load_safetensors(clip_g_path, device, disable_mmap)
|
||||
for key in list(clip_g_sd.keys()):
|
||||
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
|
||||
else:
|
||||
if state_dict is not None:
|
||||
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||
# found clip_g: remove prefix "text_encoders.clip_g."
|
||||
logger.info("clip_g is included in the checkpoint")
|
||||
@@ -108,34 +173,53 @@ def load_clip_g(
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
elif clip_g_path is None:
|
||||
logger.info("clip_g is not included in the checkpoint and clip_g_path is not provided")
|
||||
return None
|
||||
|
||||
# load clip_g
|
||||
logger.info("Building CLIP-G")
|
||||
config = CLIPTextConfig(
|
||||
vocab_size=49408,
|
||||
hidden_size=1280,
|
||||
intermediate_size=5120,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=20,
|
||||
max_position_embeddings=77,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-05,
|
||||
dropout=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=1.0,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
model_type="clip_text_model",
|
||||
projection_dim=1280,
|
||||
# torch_dtype="float32",
|
||||
# transformers_version="4.25.0.dev0",
|
||||
)
|
||||
with init_empty_weights():
|
||||
clip = CLIPTextModelWithProjection(config)
|
||||
|
||||
if clip_g_sd is None:
|
||||
clip_g = None
|
||||
else:
|
||||
logger.info("Building ClipG")
|
||||
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = clip_g.load_state_dict(clip_g_sd)
|
||||
logger.info(f"Loaded ClipG: {info}")
|
||||
clip_g.set_attn_mode(attn_mode)
|
||||
return clip_g
|
||||
logger.info(f"Loading state dict from {clip_g_path}")
|
||||
clip_g_sd = load_safetensors(clip_g_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||
info = clip.load_state_dict(clip_g_sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded CLIP-G: {info}")
|
||||
return clip
|
||||
|
||||
|
||||
def load_t5xxl(
|
||||
state_dict: Dict,
|
||||
t5xxl_path: Optional[str],
|
||||
attn_mode: str,
|
||||
dtype: Optional[Union[str, torch.dtype]],
|
||||
device: Union[str, torch.device],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[Dict] = None,
|
||||
):
|
||||
t5xxl_sd = None
|
||||
if t5xxl_path:
|
||||
logger.info(f"Loading t5xxl from {t5xxl_path}...")
|
||||
t5xxl_sd = load_safetensors(t5xxl_path, device, disable_mmap)
|
||||
for key in list(t5xxl_sd.keys()):
|
||||
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
|
||||
else:
|
||||
if state_dict is not None:
|
||||
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
|
||||
# found t5xxl: remove prefix "text_encoders.t5xxl."
|
||||
logger.info("t5xxl is included in the checkpoint")
|
||||
@@ -144,29 +228,19 @@ def load_t5xxl(
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
elif t5xxl_path is None:
|
||||
logger.info("t5xxl is not included in the checkpoint and t5xxl_path is not provided")
|
||||
return None
|
||||
|
||||
if t5xxl_sd is None:
|
||||
t5xxl = None
|
||||
else:
|
||||
logger.info("Building T5XXL")
|
||||
|
||||
# workaround for T5XXL model creation: create with fp16 takes too long TODO support virtual device
|
||||
t5xxl = sd3_models.create_t5xxl(device, torch.float32, t5xxl_sd)
|
||||
t5xxl.to(dtype=dtype)
|
||||
|
||||
logger.info("Loading state dict...")
|
||||
info = t5xxl.load_state_dict(t5xxl_sd)
|
||||
logger.info(f"Loaded T5XXL: {info}")
|
||||
t5xxl.set_attn_mode(attn_mode)
|
||||
return t5xxl
|
||||
return flux_utils_load_t5xxl(t5xxl_path, dtype, device, disable_mmap, state_dict=t5xxl_sd)
|
||||
|
||||
|
||||
def load_vae(
|
||||
state_dict: Dict,
|
||||
vae_path: Optional[str],
|
||||
vae_dtype: Optional[Union[str, torch.dtype]],
|
||||
device: Optional[Union[str, torch.device]],
|
||||
disable_mmap: bool = False,
|
||||
state_dict: Optional[Dict] = None,
|
||||
):
|
||||
vae_sd = {}
|
||||
if vae_path:
|
||||
@@ -181,299 +255,15 @@ def load_vae(
|
||||
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
|
||||
|
||||
logger.info("Building VAE")
|
||||
vae = sd3_models.SDVAE()
|
||||
vae = sd3_models.SDVAE(vae_dtype, device)
|
||||
logger.info("Loading state dict...")
|
||||
info = vae.load_state_dict(vae_sd)
|
||||
logger.info(f"Loaded VAE: {info}")
|
||||
vae.to(device=device, dtype=vae_dtype)
|
||||
vae.to(device=device, dtype=vae_dtype) # make sure it's in the right device and dtype
|
||||
return vae
|
||||
|
||||
|
||||
def load_models(
|
||||
ckpt_path: str,
|
||||
clip_l_path: str,
|
||||
clip_g_path: str,
|
||||
t5xxl_path: str,
|
||||
vae_path: str,
|
||||
attn_mode: str,
|
||||
device: Union[str, torch.device],
|
||||
weight_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
disable_mmap: bool = False,
|
||||
clip_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
t5xxl_device: Optional[Union[str, torch.device]] = None,
|
||||
t5xxl_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
vae_dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
):
|
||||
"""
|
||||
Load SD3 models from checkpoint files.
|
||||
|
||||
Args:
|
||||
ckpt_path: Path to the SD3 checkpoint file.
|
||||
clip_l_path: Path to the clip_l checkpoint file.
|
||||
clip_g_path: Path to the clip_g checkpoint file.
|
||||
t5xxl_path: Path to the t5xxl checkpoint file.
|
||||
vae_path: Path to the VAE checkpoint file.
|
||||
attn_mode: Attention mode for MMDiT model.
|
||||
device: Device for MMDiT model.
|
||||
weight_dtype: Default dtype of weights for all models. This is weight dtype, so the model dtype may be different.
|
||||
disable_mmap: Disable memory mapping when loading state dict.
|
||||
clip_dtype: Dtype for Clip models, or None to use default dtype.
|
||||
t5xxl_device: Device for T5XXL model to load T5XXL in another device (eg. gpu). Default is None to use device.
|
||||
t5xxl_dtype: Dtype for T5XXL model, or None to use default dtype.
|
||||
vae_dtype: Dtype for VAE model, or None to use default dtype.
|
||||
|
||||
Returns:
|
||||
Tuple of MMDiT, ClipL, ClipG, T5XXL, and VAE models.
|
||||
"""
|
||||
|
||||
# In SD1/2 and SDXL, the model is created with empty weights and then loaded with state dict.
|
||||
# However, in SD3, Clip and T5XXL models are created with dtype, so we need to set dtype before loading state dict.
|
||||
# Therefore, we need clip_dtype and t5xxl_dtype.
|
||||
|
||||
def load_state_dict(path: str, dvc: Union[str, torch.device] = device):
|
||||
if disable_mmap:
|
||||
return safetensors.torch.load(open(path, "rb").read())
|
||||
else:
|
||||
try:
|
||||
return load_file(path, device=dvc)
|
||||
except:
|
||||
return load_file(path) # prevent device invalid Error
|
||||
|
||||
t5xxl_device = t5xxl_device or device
|
||||
clip_dtype = clip_dtype or weight_dtype or torch.float32
|
||||
t5xxl_dtype = t5xxl_dtype or weight_dtype or torch.float32
|
||||
vae_dtype = vae_dtype or weight_dtype or torch.float32
|
||||
|
||||
logger.info(f"Loading SD3 models from {ckpt_path}...")
|
||||
state_dict = load_state_dict(ckpt_path)
|
||||
|
||||
# load clip_l
|
||||
clip_l_sd = None
|
||||
if clip_l_path:
|
||||
logger.info(f"Loading clip_l from {clip_l_path}...")
|
||||
clip_l_sd = load_state_dict(clip_l_path)
|
||||
for key in list(clip_l_sd.keys()):
|
||||
clip_l_sd["transformer." + key] = clip_l_sd.pop(key)
|
||||
else:
|
||||
if "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||
# found clip_l: remove prefix "text_encoders.clip_l."
|
||||
logger.info("clip_l is included in the checkpoint")
|
||||
clip_l_sd = {}
|
||||
prefix = "text_encoders.clip_l."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
clip_l_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
|
||||
# load clip_g
|
||||
clip_g_sd = None
|
||||
if clip_g_path:
|
||||
logger.info(f"Loading clip_g from {clip_g_path}...")
|
||||
clip_g_sd = load_state_dict(clip_g_path)
|
||||
for key in list(clip_g_sd.keys()):
|
||||
clip_g_sd["transformer." + key] = clip_g_sd.pop(key)
|
||||
else:
|
||||
if "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight" in state_dict:
|
||||
# found clip_g: remove prefix "text_encoders.clip_g."
|
||||
logger.info("clip_g is included in the checkpoint")
|
||||
clip_g_sd = {}
|
||||
prefix = "text_encoders.clip_g."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
clip_g_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
|
||||
# load t5xxl
|
||||
t5xxl_sd = None
|
||||
if t5xxl_path:
|
||||
logger.info(f"Loading t5xxl from {t5xxl_path}...")
|
||||
t5xxl_sd = load_state_dict(t5xxl_path, t5xxl_device)
|
||||
for key in list(t5xxl_sd.keys()):
|
||||
t5xxl_sd["transformer." + key] = t5xxl_sd.pop(key)
|
||||
else:
|
||||
if "text_encoders.t5xxl.transformer.encoder.block.0.layer.0.SelfAttention.k.weight" in state_dict:
|
||||
# found t5xxl: remove prefix "text_encoders.t5xxl."
|
||||
logger.info("t5xxl is included in the checkpoint")
|
||||
t5xxl_sd = {}
|
||||
prefix = "text_encoders.t5xxl."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(prefix):
|
||||
t5xxl_sd[k[len(prefix) :]] = state_dict.pop(k)
|
||||
|
||||
# MMDiT and VAE
|
||||
vae_sd = {}
|
||||
if vae_path:
|
||||
logger.info(f"Loading VAE from {vae_path}...")
|
||||
vae_sd = load_state_dict(vae_path)
|
||||
else:
|
||||
# remove prefix "first_stage_model."
|
||||
vae_sd = {}
|
||||
vae_prefix = "first_stage_model."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(vae_prefix):
|
||||
vae_sd[k[len(vae_prefix) :]] = state_dict.pop(k)
|
||||
|
||||
mmdit_prefix = "model.diffusion_model."
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith(mmdit_prefix):
|
||||
state_dict[k[len(mmdit_prefix) :]] = state_dict.pop(k)
|
||||
else:
|
||||
state_dict.pop(k) # remove other keys
|
||||
|
||||
# load MMDiT
|
||||
logger.info("Building MMDit")
|
||||
with init_empty_weights():
|
||||
mmdit = sd3_models.create_mmdit_sd3_medium_configs(attn_mode)
|
||||
|
||||
logger.info("Loading state dict...")
|
||||
info = sdxl_model_util._load_state_dict_on_device(mmdit, state_dict, device, weight_dtype)
|
||||
logger.info(f"Loaded MMDiT: {info}")
|
||||
|
||||
# load ClipG and ClipL
|
||||
if clip_l_sd is None:
|
||||
clip_l = None
|
||||
else:
|
||||
logger.info("Building ClipL")
|
||||
clip_l = sd3_models.create_clip_l(device, clip_dtype, clip_l_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = clip_l.load_state_dict(clip_l_sd)
|
||||
logger.info(f"Loaded ClipL: {info}")
|
||||
clip_l.set_attn_mode(attn_mode)
|
||||
|
||||
if clip_g_sd is None:
|
||||
clip_g = None
|
||||
else:
|
||||
logger.info("Building ClipG")
|
||||
clip_g = sd3_models.create_clip_g(device, clip_dtype, clip_g_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = clip_g.load_state_dict(clip_g_sd)
|
||||
logger.info(f"Loaded ClipG: {info}")
|
||||
clip_g.set_attn_mode(attn_mode)
|
||||
|
||||
# load T5XXL
|
||||
if t5xxl_sd is None:
|
||||
t5xxl = None
|
||||
else:
|
||||
logger.info("Building T5XXL")
|
||||
t5xxl = sd3_models.create_t5xxl(t5xxl_device, t5xxl_dtype, t5xxl_sd)
|
||||
logger.info("Loading state dict...")
|
||||
info = t5xxl.load_state_dict(t5xxl_sd)
|
||||
logger.info(f"Loaded T5XXL: {info}")
|
||||
t5xxl.set_attn_mode(attn_mode)
|
||||
|
||||
# load VAE
|
||||
logger.info("Building VAE")
|
||||
vae = sd3_models.SDVAE()
|
||||
logger.info("Loading state dict...")
|
||||
info = vae.load_state_dict(vae_sd)
|
||||
logger.info(f"Loaded VAE: {info}")
|
||||
vae.to(device=device, dtype=vae_dtype)
|
||||
|
||||
return mmdit, clip_l, clip_g, t5xxl, vae
|
||||
|
||||
|
||||
# endregion
|
||||
# region utils
|
||||
|
||||
|
||||
def get_cond(
|
||||
prompt: str,
|
||||
tokenizer: sd3_models.SD3Tokenizer,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
|
||||
print(t5_tokens)
|
||||
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)
|
||||
|
||||
|
||||
def get_cond_from_tokens(
|
||||
l_tokens,
|
||||
g_tokens,
|
||||
t5_tokens,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
l_out, l_pooled = clip_l.encode_token_weights(l_tokens)
|
||||
g_out, g_pooled = clip_g.encode_token_weights(g_tokens)
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
if device is not None:
|
||||
lg_out = lg_out.to(device=device)
|
||||
l_pooled = l_pooled.to(device=device)
|
||||
g_pooled = g_pooled.to(device=device)
|
||||
if dtype is not None:
|
||||
lg_out = lg_out.to(dtype=dtype)
|
||||
l_pooled = l_pooled.to(dtype=dtype)
|
||||
g_pooled = g_pooled.to(dtype=dtype)
|
||||
|
||||
# t5xxl may be in another device (eg. cpu)
|
||||
if t5_tokens is None:
|
||||
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
|
||||
else:
|
||||
t5_out, _ = t5xxl.encode_token_weights(t5_tokens) # t5_out is [1, 77, 4096], t5_pooled is None
|
||||
if device is not None:
|
||||
t5_out = t5_out.to(device=device)
|
||||
if dtype is not None:
|
||||
t5_out = t5_out.to(dtype=dtype)
|
||||
|
||||
# return torch.cat([lg_out, t5_out], dim=-2), torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
return lg_out, t5_out, torch.cat((l_pooled, g_pooled), dim=-1)
|
||||
|
||||
|
||||
# used if other sd3 models is available
|
||||
r"""
|
||||
def get_sd3_configs(state_dict: Dict):
|
||||
# Important configuration values can be quickly determined by checking shapes in the source file
|
||||
# Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change)
|
||||
# prefix = "model.diffusion_model."
|
||||
prefix = ""
|
||||
|
||||
patch_size = state_dict[prefix + "x_embedder.proj.weight"].shape[2]
|
||||
depth = state_dict[prefix + "x_embedder.proj.weight"].shape[0] // 64
|
||||
num_patches = state_dict[prefix + "pos_embed"].shape[1]
|
||||
pos_embed_max_size = round(math.sqrt(num_patches))
|
||||
adm_in_channels = state_dict[prefix + "y_embedder.mlp.0.weight"].shape[1]
|
||||
context_shape = state_dict[prefix + "context_embedder.weight"].shape
|
||||
context_embedder_config = {
|
||||
"target": "torch.nn.Linear",
|
||||
"params": {"in_features": context_shape[1], "out_features": context_shape[0]},
|
||||
}
|
||||
return {
|
||||
"patch_size": patch_size,
|
||||
"depth": depth,
|
||||
"num_patches": num_patches,
|
||||
"pos_embed_max_size": pos_embed_max_size,
|
||||
"adm_in_channels": adm_in_channels,
|
||||
"context_embedder": context_embedder_config,
|
||||
}
|
||||
|
||||
|
||||
def create_mmdit_from_sd3_checkpoint(state_dict: Dict, attn_mode: str = "xformers"):
|
||||
""
|
||||
Doesn't load state dict.
|
||||
""
|
||||
sd3_configs = get_sd3_configs(state_dict)
|
||||
|
||||
mmdit = sd3_models.MMDiT(
|
||||
input_size=None,
|
||||
pos_embed_max_size=sd3_configs["pos_embed_max_size"],
|
||||
patch_size=sd3_configs["patch_size"],
|
||||
in_channels=16,
|
||||
adm_in_channels=sd3_configs["adm_in_channels"],
|
||||
depth=sd3_configs["depth"],
|
||||
mlp_ratio=4,
|
||||
qk_norm=None,
|
||||
num_patches=sd3_configs["num_patches"],
|
||||
context_size=4096,
|
||||
attn_mode=attn_mode,
|
||||
)
|
||||
return mmdit
|
||||
"""
|
||||
|
||||
|
||||
class ModelSamplingDiscreteFlow:
|
||||
@@ -509,6 +299,3 @@ class ModelSamplingDiscreteFlow:
|
||||
# assert max_denoise is False, "max_denoise not implemented"
|
||||
# max_denoise is always True, I'm not sure why it's there
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -3,7 +3,7 @@ import glob
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
|
||||
|
||||
from library import sd3_utils, train_util
|
||||
from library import sd3_models
|
||||
@@ -48,45 +48,79 @@ class Sd3TokenizeStrategy(TokenizeStrategy):
|
||||
|
||||
|
||||
class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
def __init__(self, apply_lg_attn_mask: Optional[bool] = None, apply_t5_attn_mask: Optional[bool] = None) -> None:
|
||||
"""
|
||||
Args:
|
||||
apply_t5_attn_mask: Default value for apply_t5_attn_mask.
|
||||
"""
|
||||
self.apply_lg_attn_mask = apply_lg_attn_mask
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
|
||||
def encode_tokens(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
tokens: List[torch.Tensor],
|
||||
apply_lg_attn_mask: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
apply_lg_attn_mask: Optional[bool] = False,
|
||||
apply_t5_attn_mask: Optional[bool] = False,
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
returned embeddings are not masked
|
||||
"""
|
||||
clip_l, clip_g, t5xxl = models
|
||||
clip_l: CLIPTextModel
|
||||
clip_g: CLIPTextModelWithProjection
|
||||
t5xxl: T5EncoderModel
|
||||
|
||||
if apply_lg_attn_mask is None:
|
||||
apply_lg_attn_mask = self.apply_lg_attn_mask
|
||||
if apply_t5_attn_mask is None:
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = tokens[:3]
|
||||
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:] if len(tokens) > 3 else [None, None, None]
|
||||
|
||||
if len(tokens) > 3:
|
||||
l_attn_mask, g_attn_mask, t5_attn_mask = tokens[3:]
|
||||
if not apply_lg_attn_mask:
|
||||
l_attn_mask = None
|
||||
g_attn_mask = None
|
||||
else:
|
||||
l_attn_mask = l_attn_mask.to(clip_l.device)
|
||||
g_attn_mask = g_attn_mask.to(clip_g.device)
|
||||
if not apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
else:
|
||||
t5_attn_mask = t5_attn_mask.to(t5xxl.device)
|
||||
else:
|
||||
l_attn_mask = None
|
||||
g_attn_mask = None
|
||||
t5_attn_mask = None
|
||||
|
||||
if l_tokens is None:
|
||||
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
|
||||
lg_out = None
|
||||
lg_pooled = None
|
||||
else:
|
||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||
l_out, l_pooled = clip_l(l_tokens)
|
||||
g_out, g_pooled = clip_g(g_tokens)
|
||||
if apply_lg_attn_mask:
|
||||
l_out = l_out * l_attn_mask.to(l_out.device).unsqueeze(-1)
|
||||
g_out = g_out * g_attn_mask.to(g_out.device).unsqueeze(-1)
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
with torch.no_grad():
|
||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||
prompt_embeds = clip_l(l_tokens.to(clip_l.device), l_attn_mask, output_hidden_states=True)
|
||||
l_pooled = prompt_embeds[0]
|
||||
l_out = prompt_embeds.hidden_states[-2]
|
||||
|
||||
prompt_embeds = clip_g(g_tokens.to(clip_g.device), g_attn_mask, output_hidden_states=True)
|
||||
g_pooled = prompt_embeds[0]
|
||||
g_out = prompt_embeds.hidden_states[-2]
|
||||
|
||||
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
|
||||
if t5xxl is not None and t5_tokens is not None:
|
||||
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
|
||||
if apply_t5_attn_mask:
|
||||
t5_out = t5_out * t5_attn_mask.to(t5_out.device).unsqueeze(-1)
|
||||
with torch.no_grad():
|
||||
t5_out, _ = t5xxl(t5_tokens.to(t5xxl.device), t5_attn_mask, return_dict=False, output_hidden_states=True)
|
||||
else:
|
||||
t5_out = None
|
||||
|
||||
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
|
||||
return [lg_out, t5_out, lg_pooled]
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask] # masks are used for attention masking in transformer
|
||||
|
||||
def concat_encodings(
|
||||
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
|
||||
@@ -132,39 +166,38 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
return False
|
||||
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
|
||||
return False
|
||||
# t5xxl is optional
|
||||
if "apply_lg_attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
|
||||
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def mask_lg_attn(self, lg_out: np.ndarray, l_attn_mask: np.ndarray, g_attn_mask: np.ndarray) -> np.ndarray:
|
||||
l_out = lg_out[..., :768]
|
||||
g_out = lg_out[..., 768:] # 1280
|
||||
l_out = l_out * np.expand_dims(l_attn_mask, -1) # l_out = l_out * l_attn_mask.
|
||||
g_out = g_out * np.expand_dims(g_attn_mask, -1) # g_out = g_out * g_attn_mask.
|
||||
return np.concatenate([l_out, g_out], axis=-1)
|
||||
|
||||
def mask_t5_attn(self, t5_out: np.ndarray, t5_attn_mask: np.ndarray) -> np.ndarray:
|
||||
return t5_out * np.expand_dims(t5_attn_mask, -1)
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
lg_out = data["lg_out"]
|
||||
lg_pooled = data["lg_pooled"]
|
||||
t5_out = data["t5_out"] if "t5_out" in data else None
|
||||
t5_out = data["t5_out"]
|
||||
|
||||
if self.apply_lg_attn_mask:
|
||||
l_attn_mask = data["clip_l_attn_mask"]
|
||||
g_attn_mask = data["clip_g_attn_mask"]
|
||||
lg_out = self.mask_lg_attn(lg_out, l_attn_mask, g_attn_mask)
|
||||
l_attn_mask = data["clip_l_attn_mask"]
|
||||
g_attn_mask = data["clip_g_attn_mask"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
|
||||
if self.apply_t5_attn_mask and t5_out is not None:
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
t5_out = self.mask_t5_attn(t5_out, t5_attn_mask)
|
||||
|
||||
return [lg_out, t5_out, lg_pooled]
|
||||
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
@@ -174,7 +207,7 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
lg_out, t5_out, lg_pooled = sd3_text_encoding_strategy.encode_tokens(
|
||||
lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask = sd3_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, tokens_and_masks, self.apply_lg_attn_mask, self.apply_t5_attn_mask
|
||||
)
|
||||
|
||||
@@ -182,38 +215,41 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
lg_out = lg_out.float()
|
||||
if lg_pooled.dtype == torch.bfloat16:
|
||||
lg_pooled = lg_pooled.float()
|
||||
if t5_out is not None and t5_out.dtype == torch.bfloat16:
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
|
||||
lg_out = lg_out.cpu().numpy()
|
||||
lg_pooled = lg_pooled.cpu().numpy()
|
||||
if t5_out is not None:
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
|
||||
l_attn_mask = tokens_and_masks[3].cpu().numpy()
|
||||
g_attn_mask = tokens_and_masks[4].cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
lg_out_i = lg_out[i]
|
||||
t5_out_i = t5_out[i] if t5_out is not None else None
|
||||
t5_out_i = t5_out[i]
|
||||
lg_pooled_i = lg_pooled[i]
|
||||
l_attn_mask_i = l_attn_mask[i]
|
||||
g_attn_mask_i = g_attn_mask[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_lg_attn_mask = self.apply_lg_attn_mask
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
clip_l_attn_mask, clip_g_attn_mask, t5_attn_mask = tokens_and_masks[3:6]
|
||||
clip_l_attn_mask_i = clip_l_attn_mask[i].cpu().numpy()
|
||||
clip_g_attn_mask_i = clip_g_attn_mask[i].cpu().numpy()
|
||||
t5_attn_mask_i = t5_attn_mask[i].cpu().numpy() if t5_attn_mask is not None else None # shouldn't be None
|
||||
kwargs = {}
|
||||
if t5_out is not None:
|
||||
kwargs["t5_out"] = t5_out_i
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
lg_out=lg_out_i,
|
||||
lg_pooled=lg_pooled_i,
|
||||
clip_l_attn_mask=clip_l_attn_mask_i,
|
||||
clip_g_attn_mask=clip_g_attn_mask_i,
|
||||
t5_out=t5_out_i,
|
||||
clip_l_attn_mask=l_attn_mask_i,
|
||||
clip_g_attn_mask=g_attn_mask_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
**kwargs,
|
||||
apply_lg_attn_mask=apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
|
||||
|
||||
|
||||
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
@@ -246,41 +282,3 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test code for Sd3TokenizeStrategy
|
||||
# tokenizer = sd3_models.SD3Tokenizer()
|
||||
strategy = Sd3TokenizeStrategy(256)
|
||||
text = "hello world"
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
|
||||
# print(l_tokens.shape)
|
||||
print(l_tokens)
|
||||
print(g_tokens)
|
||||
print(t5_tokens)
|
||||
|
||||
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
|
||||
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||
t5_tokens_2 = strategy.t5xxl(
|
||||
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
print(l_tokens_2)
|
||||
print(g_tokens_2)
|
||||
print(t5_tokens_2)
|
||||
|
||||
# compare
|
||||
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
|
||||
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
|
||||
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
|
||||
|
||||
text = ",".join(["hello world! this is long text"] * 50)
|
||||
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
|
||||
print(l_tokens)
|
||||
print(g_tokens)
|
||||
print(t5_tokens)
|
||||
|
||||
print(f"model max length l: {strategy.clip_l.model_max_length}")
|
||||
print(f"model max length g: {strategy.clip_g.model_max_length}")
|
||||
print(f"model max length t5: {strategy.t5xxl.model_max_length}")
|
||||
|
||||
@@ -5967,6 +5967,37 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
return prompt_dict
|
||||
|
||||
|
||||
def load_prompts(prompt_file: str) -> List[Dict]:
|
||||
# read prompts
|
||||
if prompt_file.endswith(".txt"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif prompt_file.endswith(".toml"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif prompt_file.endswith(".json"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
# preprocess prompts
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
from library.train_util import line_to_prompt_dict
|
||||
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
|
||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator: Accelerator,
|
||||
|
||||
@@ -13,12 +13,16 @@ from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncest
|
||||
import cv2
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
|
||||
|
||||
# region Logging
|
||||
|
||||
|
||||
def add_logging_arguments(parser):
|
||||
parser.add_argument(
|
||||
"--console_log_level",
|
||||
@@ -85,6 +89,11 @@ def setup_logging(args=None, log_level=None, reset=False):
|
||||
logger.info(msg_init)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region PyTorch utils
|
||||
|
||||
|
||||
def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> torch.dtype:
|
||||
"""
|
||||
Convert a string to a torch.dtype
|
||||
@@ -304,6 +313,35 @@ class MemoryEfficientSafeOpen:
|
||||
# return byte_tensor.view(torch.uint8).to(torch.float16).reshape(shape)
|
||||
raise ValueError(f"Unsupported float8 type: {dtype_str} (upgrade PyTorch to support float8 types)")
|
||||
|
||||
def load_safetensors(
|
||||
path: str, device: Union[str, torch.device], disable_mmap: bool = False, dtype: Optional[torch.dtype] = torch.float32
|
||||
) -> dict[str, torch.Tensor]:
|
||||
if disable_mmap:
|
||||
# return safetensors.torch.load(open(path, "rb").read())
|
||||
# use experimental loader
|
||||
# logger.info(f"Loading without mmap (experimental)")
|
||||
state_dict = {}
|
||||
with MemoryEfficientSafeOpen(path) as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key).to(device, dtype=dtype)
|
||||
return state_dict
|
||||
else:
|
||||
try:
|
||||
state_dict = load_file(path, device=device)
|
||||
except:
|
||||
state_dict = load_file(path) # prevent device invalid Error
|
||||
if dtype is not None:
|
||||
for key in state_dict.keys():
|
||||
state_dict[key] = state_dict[key].to(dtype=dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Image utils
|
||||
|
||||
|
||||
def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
|
||||
|
||||
@@ -323,9 +361,9 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
|
||||
return resized_cv2
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# TODO make inf_utils.py
|
||||
|
||||
|
||||
# region Gradual Latent hires fix
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user