feat: Add shift option to --timestep_sampling in FLUX.1 fine-tuning and LoRA training

This commit is contained in:
Kohya S
2024-08-25 16:01:24 +09:00
parent ea9242653c
commit 72287d39c7
2 changed files with 17 additions and 2 deletions

View File

@@ -380,9 +380,19 @@ def get_noisy_model_input_and_timesteps(
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -559,9 +569,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--timestep_sampling",
choices=["sigma", "uniform", "sigmoid"],
choices=["sigma", "uniform", "sigmoid", "shift"],
default="sigma",
help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法sigma、random uniform、またはrandom normalのsigmoid",
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
" / タイムステップをサンプリングする方法sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
)
parser.add_argument(
"--sigmoid_scale",