diff --git a/library/train_util.py b/library/train_util.py index 37ed0a99..01fa6467 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5935,7 +5935,10 @@ def save_sd_model_on_train_end_common( def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + if min_timestep < max_timestep: + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") + else: + timesteps = torch.full((b_size,), max_timestep, device="cpu") timesteps = timesteps.long().to(device) return timesteps