Compare commits

...

4 Commits

Author SHA1 Message Date
Kohya S
c0c36a4e2f fix: remove duplicated latent normalization in decoding 2025-07-15 21:58:03 +09:00
Kohya S
25771a5180 fix: update help text for cfg_trunc_ratio argument 2025-07-15 21:53:13 +09:00
Kohya S
e0fcb5152a feat: support Neta Lumina all-in-one weights 2025-07-15 21:34:35 +09:00
Kohya S
13ccfc39f8 fix: update flow matching loss and variable names 2025-07-13 21:26:06 +09:00
3 changed files with 42 additions and 14 deletions

View File

@@ -44,10 +44,21 @@ def load_lumina_model(
"""
logger.info("Building Lumina")
with torch.device("meta"):
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(dtype)
model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(
dtype
)
logger.info(f"Loading state dict from {ckpt_path}")
state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
# Neta-Lumina support
if "model.diffusion_model.cap_embedder.0.weight" in state_dict:
# remove "model.diffusion_model." prefix
filtered_state_dict = {
k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.")
}
state_dict = filtered_state_dict
info = model.load_state_dict(state_dict, strict=False, assign=True)
logger.info(f"Loaded Lumina: {info}")
return model
@@ -78,6 +89,13 @@ def load_ae(
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)
# Neta-Lumina support
if "vae.decoder.conv_in.bias" in sd:
# remove "vae." prefix
filtered_sd = {k.replace("vae.", ""): v for k, v in sd.items() if k.startswith("vae.")}
sd = filtered_sd
info = ae.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded AE: {info}")
return ae
@@ -152,6 +170,16 @@ def load_gemma2(
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
# Neta-Lumina support
if "text_encoders.gemma2_2b.logit_scale" in sd:
# remove "text_encoders.gemma2_2b.transformer.model." prefix
filtered_sd = {
k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v
for k, v in sd.items()
if k.startswith("text_encoders.gemma2_2b.transformer.model.")
}
sd = filtered_sd
info = gemma2.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Gemma2: {info}")
return gemma2
@@ -173,7 +201,6 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
return x
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
# Embedding layers
"time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
@@ -211,11 +238,11 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict
for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
# Handle block-specific patterns
if '().' in diff_key:
if "()." in diff_key:
for block_idx in range(num_double_blocks):
block_alpha_key = alpha_key.replace('().', f'{block_idx}.')
block_diff_key = diff_key.replace('().', f'{block_idx}.')
block_alpha_key = alpha_key.replace("().", f"{block_idx}.")
block_diff_key = diff_key.replace("().", f"{block_idx}.")
# Search for and convert block-specific keys
for input_key, value in list(sd.items()):
if input_key == block_diff_key:
@@ -228,6 +255,5 @@ def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict
else:
print(f"Not found: {diff_key}")
logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
return new_sd

View File

@@ -158,7 +158,7 @@ def generate_image(
# 5. Decode latents
#
logger.info("Decoding image...")
latents = latents / ae.scale_factor + ae.shift_factor
# latents = latents / ae.scale_factor + ae.shift_factor
with torch.no_grad():
image = ae.decode(latents.to(ae_dtype))
image = (image / 2 + 0.5).clamp(0, 1)
@@ -231,13 +231,13 @@ def setup_parser() -> argparse.ArgumentParser:
"--cfg_trunc_ratio",
type=float,
default=0.25,
help="TBD",
help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25% of timesteps will be guided.",
)
parser.add_argument(
"--renorm_cfg",
type=float,
default=1.0,
help="TBD",
help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.",
)
parser.add_argument(
"--use_flash_attn",

View File

@@ -294,7 +294,7 @@ def train(args):
# load lumina
nextdit = lumina_util.load_lumina_model(
args.pretrained_model_name_or_path,
loading_dtype,
weight_dtype,
torch.device("cpu"),
disable_mmap=args.disable_mmap_load_safetensors,
use_flash_attn=args.use_flash_attn,
@@ -494,6 +494,8 @@ def train(args):
clean_memory_on_device(accelerator.device)
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
@@ -739,7 +741,7 @@ def train(args):
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = nextdit(
x=img, # image latents (B, C, H, W)
x=noisy_model_input, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(
@@ -751,8 +753,8 @@ def train(args):
args, model_pred, noisy_model_input, sigmas
)
# flow matching loss: this is different from SD3
target = noise - latents
# flow matching loss
target = latents - noise
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(