mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix: validation timestep generation fails on SD/SDXL training
This commit is contained in:
@@ -5935,7 +5935,10 @@ def save_sd_model_on_train_end_common(
|
|||||||
|
|
||||||
|
|
||||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
|
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
|
||||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
if min_timestep < max_timestep:
|
||||||
|
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||||
|
else:
|
||||||
|
timesteps = torch.full((b_size,), max_timestep, device="cpu")
|
||||||
timesteps = timesteps.long().to(device)
|
timesteps = timesteps.long().to(device)
|
||||||
return timesteps
|
return timesteps
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user