mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
feat: support Neta Lumina all-in-one weights
This commit is contained in:
@@ -44,10 +44,21 @@ def load_lumina_model(
|
||||
"""
|
||||
logger.info("Building Lumina")
|
||||
with torch.device("meta"):
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype)
|
||||
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(
|
||||
dtype
|
||||
)
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
# Neta-Lumina support
|
||||
if "model.diffusion_model.cap_embedder.0.weight" in state_dict:
|
||||
# remove "model.diffusion_model." prefix
|
||||
filtered_state_dict = {
|
||||
k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.")
|
||||
}
|
||||
state_dict = filtered_state_dict
|
||||
|
||||
info = model.load_state_dict(state_dict, strict=False, assign=True)
|
||||
logger.info(f"Loaded Lumina: {info}")
|
||||
return model
|
||||
@@ -78,6 +89,13 @@ def load_ae(
|
||||
|
||||
logger.info(f"Loading state dict from {ckpt_path}")
|
||||
sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
|
||||
|
||||
# Neta-Lumina support
|
||||
if "vae.decoder.conv_in.bias" in sd:
|
||||
# remove "vae." prefix
|
||||
filtered_sd = {k.replace("vae.", ""): v for k, v in sd.items() if k.startswith("vae.")}
|
||||
sd = filtered_sd
|
||||
|
||||
info = ae.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded AE: {info}")
|
||||
return ae
|
||||
@@ -152,6 +170,16 @@ def load_gemma2(
|
||||
break # the model doesn't have annoying prefix
|
||||
sd[new_key] = sd.pop(key)
|
||||
|
||||
# Neta-Lumina support
|
||||
if "text_encoders.gemma2_2b.logit_scale" in sd:
|
||||
# remove "text_encoders.gemma2_2b.transformer.model." prefix
|
||||
filtered_sd = {
|
||||
k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v
|
||||
for k, v in sd.items()
|
||||
if k.startswith("text_encoders.gemma2_2b.transformer.model.")
|
||||
}
|
||||
sd = filtered_sd
|
||||
|
||||
info = gemma2.load_state_dict(sd, strict=False, assign=True)
|
||||
logger.info(f"Loaded Gemma2: {info}")
|
||||
return gemma2
|
||||
@@ -173,7 +201,6 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
|
||||
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
|
||||
# Embedding layers
|
||||
"time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
|
||||
@@ -211,11 +238,11 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict
|
||||
|
||||
for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
|
||||
# Handle block-specific patterns
|
||||
if '().' in diff_key:
|
||||
if "()." in diff_key:
|
||||
for block_idx in range(num_double_blocks):
|
||||
block_alpha_key = alpha_key.replace('().', f'{block_idx}.')
|
||||
block_diff_key = diff_key.replace('().', f'{block_idx}.')
|
||||
|
||||
block_alpha_key = alpha_key.replace("().", f"{block_idx}.")
|
||||
block_diff_key = diff_key.replace("().", f"{block_idx}.")
|
||||
|
||||
# Search for and convert block-specific keys
|
||||
for input_key, value in list(sd.items()):
|
||||
if input_key == block_diff_key:
|
||||
@@ -228,6 +255,5 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict
|
||||
else:
|
||||
print(f"Not found: {diff_key}")
|
||||
|
||||
|
||||
logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
|
||||
return new_sd
|
||||
|
||||
@@ -231,13 +231,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
"--cfg_trunc_ratio",
|
||||
type=float,
|
||||
default=0.25,
|
||||
help="TBD",
|
||||
help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the last 25% of timesteps will be guided.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--renorm_cfg",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="TBD",
|
||||
help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_flash_attn",
|
||||
|
||||
Reference in New Issue
Block a user