mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Update lumina_train_network.py
This commit is contained in:
@@ -268,7 +268,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
# NextDiT forward expects (x, t, cap_feats, cap_mask)
|
||||||
model_pred = dit(
|
model_pred = dit(
|
||||||
x=img, # image latents (B, C, H, W)
|
x=img, # image latents (B, C, H, W)
|
||||||
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
t= 1 - timesteps / 1000, # timesteps需要除以1000来匹配模型预期
|
||||||
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
|
||||||
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
cap_mask=gemma2_attn_mask.to(dtype=torch.int32), # Gemma2的attention mask
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user