diff --git a/lumina_train_network.py b/lumina_train_network.py index b08e3143..095bca24 100644 --- a/lumina_train_network.py +++ b/lumina_train_network.py @@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer): # NextDiT forward expects (x, t, cap_feats, cap_mask) model_pred = dit( x=img, # image latents (B, C, H, W) - t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期 + t= 1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期 cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask )