mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Update flux_train_utils.py
This commit is contained in:
@@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, H, W = latents.shape
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
@@ -399,7 +399,10 @@ def get_noisy_model_input_and_timesteps(
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2))
|
||||
mu = get_lin_function(
|
||||
y1=get_lin_function(args.min_bucket_reso or min(args.resolution)),
|
||||
y2=get_lin_function(args.max_bucket_reso or max(args.resolution)),
|
||||
)((h // 2) * (w // 2))
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
|
||||
Reference in New Issue
Block a user