fix: validation timestep generation fails on SD/SDXL training

This commit is contained in:
Kohya S
2025-02-04 22:02:42 +09:00
parent c5b803ce94
commit a24db1d532

View File

@@ -5935,7 +5935,10 @@ def save_sd_model_on_train_end_common(
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
if min_timestep < max_timestep:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
else:
timesteps = torch.full((b_size,), max_timestep, device="cpu")
timesteps = timesteps.long().to(device)
return timesteps