diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 735bcced..e0d98c0b 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - bsz, _, H, W = latents.shape + bsz, _, h, w = latents.shape sigmas = None if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": @@ -399,7 +399,10 @@ def get_noisy_model_input_and_timesteps( logits_norm = torch.randn(bsz, device=device) logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling timesteps = logits_norm.sigmoid() - mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2)) + mu = get_lin_function( + y1=get_lin_function(args.min_bucket_reso or min(args.resolution)), + y2=get_lin_function(args.max_bucket_reso or max(args.resolution)), + )((h // 2) * (w // 2)) timesteps = time_shift(mu, 1.0, timesteps) t = timesteps.view(-1, 1, 1, 1)