mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix: validation timestep generation fails on SD/SDXL training
This commit is contained in:
@@ -5935,7 +5935,10 @@ def save_sd_model_on_train_end_common(
|
||||
|
||||
|
||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||
if min_timestep < max_timestep:
|
||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||
else:
|
||||
timesteps = torch.full((b_size,), max_timestep, device="cpu")
|
||||
timesteps = timesteps.long().to(device)
|
||||
return timesteps
|
||||
|
||||
|
||||
Reference in New Issue
Block a user