feat: support 5d tensor in get_noisy_model_input_and_timesteps

This commit is contained in:
kohya-ss
2026-02-09 12:43:37 +09:00
parent 2774e7757b
commit bf9806648c

View File

@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- bsz, _, h, w = latents.shape
+ bsz, h, w = latents.shape[0], latents.shape[-2], latents.shape[-1]
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@@ -512,7 +512,7 @@ def get_noisy_model_input_and_timesteps(
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
- sigmas = sigmas.view(-1, 1, 1, 1)
+ sigmas = sigmas.view(-1, 1, 1, 1) if latents.ndim == 4 else sigmas.view(-1, 1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)