Update flux_train_utils.py

This commit is contained in:
sdbds
2024-08-31 03:05:19 +08:00
parent 8fdfd8c857
commit 25c9040f4f

View File

@@ -371,7 +371,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps( def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz = latents.shape[0] bsz, _, H, W = latents.shape
sigmas = None sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
@@ -392,6 +392,16 @@ def get_noisy_model_input_and_timesteps(
timesteps = logits_norm.sigmoid() timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu=get_lin_function(y1=0.5, y2=1.15)((H//2) * (W//2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1) t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0 timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise noisy_model_input = (1 - t) * latents + t * noise
@@ -571,7 +581,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument( parser.add_argument(
"--timestep_sampling", "--timestep_sampling",
choices=["sigma", "uniform", "sigmoid", "shift"], choices=["sigma", "uniform", "sigmoid", "shift", "flux_shift"],
default="sigma", default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid." help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト。", " / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",