fix: update flow matching loss and variable names

This commit is contained in:
Kohya S
2025-07-13 21:26:06 +09:00
parent 30295c9668
commit 13ccfc39f8

View File

@@ -294,7 +294,7 @@ def train(args):
# load lumina
nextdit = lumina_util.load_lumina_model(
args.pretrained_model_name_or_path,
loading_dtype,
weight_dtype,
torch.device("cpu"),
disable_mmap=args.disable_mmap_load_safetensors,
use_flash_attn=args.use_flash_attn,
@@ -494,6 +494,8 @@ def train(args):
clean_memory_on_device(accelerator.device)
is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
@@ -739,7 +741,7 @@ def train(args):
with accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = nextdit(
x=img, # image latents (B, C, H, W)
x=noisy_model_input, # image latents (B, C, H, W)
t=timesteps / 1000, # timesteps需要除以1000来匹配模型预期
cap_feats=gemma2_hidden_states, # Gemma2的hidden states作为caption features
cap_mask=gemma2_attn_mask.to(
@@ -751,8 +753,8 @@ def train(args):
args, model_pred, noisy_model_input, sigmas
)
# flow matching loss: this is different from SD3
target = noise - latents
# flow matching loss
target = latents - noise
# calculate loss
huber_c = train_util.get_huber_threshold_if_needed(