diff --git a/library/lumina_util.py b/library/lumina_util.py index 452b242f..87853ef6 100644 --- a/library/lumina_util.py +++ b/library/lumina_util.py @@ -44,10 +44,21 @@ def load_lumina_model( """ logger.info("Building Lumina") with torch.device("meta"): - model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype) + model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to( + dtype + ) logger.info(f"Loading state dict from {ckpt_path}") state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + + # Neta-Lumina support + if "model.diffusion_model.cap_embedder.0.weight" in state_dict: + # remove "model.diffusion_model." prefix + filtered_state_dict = { + k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.") + } + state_dict = filtered_state_dict + info = model.load_state_dict(state_dict, strict=False, assign=True) logger.info(f"Loaded Lumina: {info}") return model @@ -78,6 +89,13 @@ def load_ae( logger.info(f"Loading state dict from {ckpt_path}") sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype) + + # Neta-Lumina support + if "vae.decoder.conv_in.bias" in sd: + # remove "vae." prefix + filtered_sd = {k.replace("vae.", ""): v for k, v in sd.items() if k.startswith("vae.")} + sd = filtered_sd + info = ae.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded AE: {info}") return ae @@ -152,6 +170,16 @@ def load_gemma2( break # the model doesn't have annoying prefix sd[new_key] = sd.pop(key) + # Neta-Lumina support + if "text_encoders.gemma2_2b.logit_scale" in sd: + # remove "text_encoders.gemma2_2b.transformer.model." prefix + filtered_sd = { + k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v + for k, v in sd.items() + if k.startswith("text_encoders.gemma2_2b.transformer.model.") + } + sd = filtered_sd + info = gemma2.load_state_dict(sd, strict=False, assign=True) logger.info(f"Loaded Gemma2: {info}") return gemma2 @@ -173,7 +201,6 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor: return x - DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = { # Embedding layers "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight", @@ -211,11 +238,11 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items(): # Handle block-specific patterns - if '().' in diff_key: + if "()." in diff_key: for block_idx in range(num_double_blocks): - block_alpha_key = alpha_key.replace('().', f'{block_idx}.') - block_diff_key = diff_key.replace('().', f'{block_idx}.') - + block_alpha_key = alpha_key.replace("().", f"{block_idx}.") + block_diff_key = diff_key.replace("().", f"{block_idx}.") + # Search for and convert block-specific keys for input_key, value in list(sd.items()): if input_key == block_diff_key: @@ -228,6 +255,5 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict else: print(f"Not found: {diff_key}") - logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format") return new_sd diff --git a/lumina_minimal_inference.py b/lumina_minimal_inference.py index 31362c00..d829616b 100644 --- a/lumina_minimal_inference.py +++ b/lumina_minimal_inference.py @@ -231,13 +231,13 @@ def setup_parser() -> argparse.ArgumentParser: "--cfg_trunc_ratio", type=float, default=0.25, - help="TBD", + help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the last 25% of timesteps will be guided.", ) parser.add_argument( "--renorm_cfg", type=float, default=1.0, - help="TBD", + help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.", ) parser.add_argument( "--use_flash_attn",