Files
Kohya-ss-sd-scripts/library/lumina_util.py

260 lines
9.5 KiB
Python

import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union
import einops
import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import Gemma2Config, Gemma2Model
from library.utils import setup_logging
from library import lumina_models, flux_models
from library.safetensors_utils import load_safetensors
import logging
# Call the project's logging helper at import time (presumably installs the
# shared handlers/format used across the library -- confirm in library.utils).
setup_logging()
# Module-level logger used by every loader and converter in this file.
logger = logging.getLogger(__name__)
# String tag identifying the Lumina 2 model family.
MODEL_VERSION_LUMINA_V2 = "lumina2"
def load_lumina_model(
    ckpt_path: str,
    dtype: Optional[torch.dtype],
    device: torch.device,
    disable_mmap: bool = False,
    use_flash_attn: bool = False,
    use_sage_attn: bool = False,
):
    """
    Build the Lumina NextDiT network and populate it from a checkpoint.

    Args:
        ckpt_path (str): Path to the checkpoint file.
        dtype (Optional[torch.dtype]): dtype applied to the model and passed to the weight loader.
        device (torch.device): Device the checkpoint tensors are loaded onto.
        disable_mmap (bool, optional): Forwarded to load_safetensors. Defaults to False.
        use_flash_attn (bool, optional): Forwarded to the model constructor. Defaults to False.
        use_sage_attn (bool, optional): Forwarded to the model constructor. Defaults to False.

    Returns:
        model (lumina_models.NextDiT): The loaded model.
    """
    neta_prefix = "model.diffusion_model."

    logger.info("Building Lumina")
    # Construct under the meta device; the checkpoint tensors are assigned
    # into the module afterwards (load_state_dict(..., assign=True)).
    with torch.device("meta"):
        model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(
            use_flash_attn=use_flash_attn,
            use_sage_attn=use_sage_attn,
        )
        model = model.to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    weights = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)

    # Neta-Lumina support: when this marker key is present, keep only the
    # entries stored under the "model.diffusion_model." prefix and drop it.
    if "model.diffusion_model.cap_embedder.0.weight" in weights:
        unprefixed = {}
        for name, tensor in weights.items():
            if name.startswith(neta_prefix):
                unprefixed[name.replace(neta_prefix, "")] = tensor
        weights = unprefixed

    # strict=False: missing/unexpected keys are reported in the result, not raised.
    load_result = model.load_state_dict(weights, strict=False, assign=True)
    logger.info(f"Loaded Lumina: {load_result}")
    return model
def load_ae(
    ckpt_path: str,
    dtype: torch.dtype,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
) -> flux_models.AutoEncoder:
    """
    Build the AutoEncoder and populate it from a checkpoint.

    Args:
        ckpt_path (str): Path to the checkpoint file.
        dtype (torch.dtype): dtype applied to the model and passed to the weight loader.
        device (Union[str, torch.device]): Device the checkpoint tensors are loaded onto.
        disable_mmap (bool, optional): Forwarded to load_safetensors. Defaults to False.

    Returns:
        ae (flux_models.AutoEncoder): The loaded model.
    """
    vae_prefix = "vae."

    logger.info("Building AutoEncoder")
    with torch.device("meta"):
        # dev and schnell have the same AE params
        ae_params = flux_models.configs["schnell"].ae_params
        ae = flux_models.AutoEncoder(ae_params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    weights = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)

    # Neta-Lumina support: when this marker key is present, keep only the
    # entries stored under the "vae." prefix and drop that prefix.
    if "vae.decoder.conv_in.bias" in weights:
        unprefixed = {}
        for name, tensor in weights.items():
            if name.startswith(vae_prefix):
                unprefixed[name.replace(vae_prefix, "")] = tensor
        weights = unprefixed

    # strict=False: missing/unexpected keys are reported in the result, not raised.
    load_result = ae.load_state_dict(weights, strict=False, assign=True)
    logger.info(f"Loaded AE: {load_result}")
    return ae
def _normalize_gemma2_keys(sd: dict) -> dict:
    """Return a new dict whose keys use the bare names Gemma2Model expects; `sd` is not modified."""
    neta_marker = "text_encoders.gemma2_2b.logit_scale"
    neta_prefix = "text_encoders.gemma2_2b.transformer.model."
    hf_prefix = "model."

    # Neta-Lumina support. This has to be decided BEFORE any "model." handling:
    # the Neta prefix itself contains "model.", so stripping that text first
    # would make the startswith() filter below match nothing.
    if neta_marker in sd:
        return {k[len(neta_prefix):]: v for k, v in sd.items() if k.startswith(neta_prefix)}

    # Otherwise drop only a *leading* "model." from each key that has one,
    # independent of key order; other keys pass through unchanged.
    return {(k[len(hf_prefix):] if k.startswith(hf_prefix) else k): v for k, v in sd.items()}


def load_gemma2(
    ckpt_path: Optional[str],
    dtype: torch.dtype,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[dict] = None,
) -> Gemma2Model:
    """
    Build a Gemma2 text encoder and populate it from a checkpoint or a state dict.

    Args:
        ckpt_path (Optional[str]): Path to the checkpoint. Only read when `state_dict` is None.
        dtype (torch.dtype): dtype passed to the weight loader.
        device (Union[str, torch.device]): Device the checkpoint tensors are loaded onto.
        disable_mmap (bool, optional): Forwarded to load_safetensors. Defaults to False.
        state_dict (Optional[dict], optional): Pre-loaded weights to use instead of
            reading `ckpt_path`. It is not mutated. Defaults to None.

    Returns:
        gemma2 (Gemma2Model): The loaded model.

    Raises:
        ValueError: If both `ckpt_path` and `state_dict` are None.
    """
    logger.info("Building Gemma2")
    # Hard-coded configuration for "google/gemma-2-2b" (see "_name_or_path").
    GEMMA2_CONFIG = {
        "_name_or_path": "google/gemma-2-2b",
        "architectures": ["Gemma2Model"],
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": 50.0,
        "bos_token_id": 2,
        "cache_implementation": "hybrid",
        "eos_token_id": 1,
        "final_logit_softcapping": 30.0,
        "head_dim": 256,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_activation": "gelu_pytorch_tanh",
        "hidden_size": 2304,
        "initializer_range": 0.02,
        "intermediate_size": 9216,
        "max_position_embeddings": 8192,
        "model_type": "gemma2",
        "num_attention_heads": 8,
        "num_hidden_layers": 26,
        "num_key_value_heads": 4,
        "pad_token_id": 0,
        "query_pre_attn_scalar": 256,
        "rms_norm_eps": 1e-06,
        "rope_theta": 10000.0,
        "sliding_window": 4096,
        "torch_dtype": "float32",
        "transformers_version": "4.44.2",
        "use_cache": True,
        "vocab_size": 256000,
    }
    config = Gemma2Config(**GEMMA2_CONFIG)
    # Create the module skeleton without real weight storage; the tensors are
    # assigned from the state dict below (assign=True).
    with init_empty_weights():
        gemma2 = Gemma2Model._from_config(config)

    if state_dict is not None:
        sd = state_dict
    else:
        if ckpt_path is None:
            raise ValueError("load_gemma2 requires either ckpt_path or state_dict")
        logger.info(f"Loading state dict from {ckpt_path}")
        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)

    # Bring checkpoint key names in line with the module's parameter names.
    sd = _normalize_gemma2_keys(sd)

    # strict=False: missing/unexpected keys are reported in `info`, not raised.
    info = gemma2.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded Gemma2: {info}")
    return gemma2
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
    """
    Expand packed 2x2 patches back into a spatial layout.

    x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
    """
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    # h and w are the patch-grid sizes; the patch itself is fixed at 2x2.
    axis_sizes = dict(h=packed_latent_height, w=packed_latent_width, ph=2, pw=2)
    return einops.rearrange(x, pattern, **axis_sizes)
def pack_latents(x: torch.Tensor) -> torch.Tensor:
    """
    Fold each 2x2 spatial patch into the last dimension.

    x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2
    """
    pattern = "b c (h ph) (w pw) -> b (h w) (c ph pw)"
    # Only the patch size is given; h and w are inferred from x.
    patch_sizes = dict(ph=2, pw=2)
    return einops.rearrange(x, pattern, **patch_sizes)
# Parameter-name translation table consumed by convert_diffusers_sd_to_alpha_vllm:
# each key is a Diffusers-style name, each value the Alpha-VLLM name it maps to.
# A "()." in an entry is a placeholder that the converter replaces with
# "<block_idx>." for every block index; entries without it are matched verbatim.
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
    # Embedding layers
    "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
    "time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight",
    # NOTE(review): this source key uses a "text_embedder." prefix while the two
    # cap_embedder weight entries above use "time_caption_embed.caption_embedder."
    # -- confirm against an actual Diffusers checkpoint.
    "text_embedder.1.bias": "cap_embedder.1.bias",
    "patch_embedder.proj.weight": "x_embedder.weight",
    "patch_embedder.proj.bias": "x_embedder.bias",
    # Attention modulation (per block)
    "transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight",
    "transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias",
    # Final layers
    "final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight",
    "final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias",
    "final_linear.weight": "final_layer.linear.weight",
    "final_linear.bias": "final_layer.linear.bias",
    # Noise refiner (per block)
    "single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight",
    "single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias",
    "single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight",
    "single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight",
    # Normalization (per block)
    "transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight",
    "transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight",
    # FFN (per block)
    "transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight",
    "transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight",
    "transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight",
}
def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict:
    """Convert Diffusers checkpoint to Alpha-VLLM format.

    Every entry of DIFFUSERS_TO_ALPHA_VLLM_MAP found in `sd` is added to the
    result under its Alpha-VLLM name. Entries containing the "()." placeholder
    are expanded for block indices 0 .. num_double_blocks - 1.

    Args:
        sd (dict): State dict with Diffusers-style keys. It is not modified.
        num_double_blocks (int): Number of block indices to try for each
            placeholder entry.

    Returns:
        dict: Shallow copy of `sd` (original keys preserved) with the converted
        keys added; values are shared with `sd`, not copied.
    """
    logger.info("Converting Diffusers checkpoint to Alpha-VLLM format")
    new_sd = sd.copy()  # Preserve original keys

    for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
        if "()." in diff_key:
            # Block-specific pattern: substitute each block index and look the
            # resulting key up directly instead of scanning the whole dict.
            for block_idx in range(num_double_blocks):
                block_diff_key = diff_key.replace("().", f"{block_idx}.")
                if block_diff_key in sd:
                    block_alpha_key = alpha_key.replace("().", f"{block_idx}.")
                    new_sd[block_alpha_key] = sd[block_diff_key]
        elif diff_key in sd:
            # Static key: copied under its new name.
            logger.info(f"Replacing {diff_key} with {alpha_key}")
            new_sd[alpha_key] = sd[diff_key]
        else:
            logger.warning(f"Not found: {diff_key}")

    logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
    return new_sd