feat: fix timestep for input_vec for Chroma

This commit is contained in:
Kohya S
2025-07-20 20:53:06 +09:00
parent b4e862626a
commit 0b763ef1f1
3 changed files with 34 additions and 9 deletions

View File

@@ -341,9 +341,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# get modulation vectors for Chroma
input_vec = None
if self.model_type == "chroma":
input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz)
input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)