Update lumina_train_network.py

This commit is contained in:
duongve13112002
2025-09-29 19:47:32 +07:00
committed by GitHub
parent f69a8f9e53
commit b32d66cfd2

View File

@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
# NextDiT forward expects (x, t, cap_feats, cap_mask)
model_pred = dit(
x=img, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
t= 1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
)