mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
feat: Add shift option to --timestep_sampling in FLUX.1 fine-tuning and LoRA training
This commit is contained in:
@@ -9,6 +9,10 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
|||||||
The command to install PyTorch is as follows:
|
The command to install PyTorch is as follows:
|
||||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||||
|
|
||||||
|
Aug 25, 2024:
|
||||||
|
Added `shift` option to `--timestep_sampling` in FLUX.1 fine-tuning and LoRA training. Shifts timesteps according to the value of `--discrete_flow_shift` (applies the shift to the sigmoid of a normally distributed random number). It may be good to start with a value of 3.1582 (=e^1.15) for `--discrete_flow_shift`.
|
||||||
|
Sample command: `--timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0`
|
||||||
|
|
||||||
Aug 24, 2024 (update 2):
|
Aug 24, 2024 (update 2):
|
||||||
|
|
||||||
__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available).
|
__Experimental__ Added an option to split the projection layers of q/k/v/txt in the attention and apply LoRA to each of them in FLUX.1 LoRA training. Specify `"split_qkv=True"` in network_args like `--network_args "split_qkv=True"` (`train_blocks` is also available).
|
||||||
|
|||||||
@@ -380,9 +380,19 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||||
else:
|
else:
|
||||||
t = torch.rand((bsz,), device=device)
|
t = torch.rand((bsz,), device=device)
|
||||||
|
|
||||||
timesteps = t * 1000.0
|
timesteps = t * 1000.0
|
||||||
t = t.view(-1, 1, 1, 1)
|
t = t.view(-1, 1, 1, 1)
|
||||||
noisy_model_input = (1 - t) * latents + t * noise
|
noisy_model_input = (1 - t) * latents + t * noise
|
||||||
|
elif args.timestep_sampling == "shift":
|
||||||
|
shift = args.discrete_flow_shift
|
||||||
|
logits_norm = torch.randn(bsz, device=device)
|
||||||
|
timesteps = logits_norm.sigmoid()
|
||||||
|
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||||
|
|
||||||
|
t = timesteps.view(-1, 1, 1, 1)
|
||||||
|
timesteps = timesteps * 1000.0
|
||||||
|
noisy_model_input = (1 - t) * latents + t * noise
|
||||||
else:
|
else:
|
||||||
# Sample a random timestep for each image
|
# Sample a random timestep for each image
|
||||||
# for weighting schemes where we sample timesteps non-uniformly
|
# for weighting schemes where we sample timesteps non-uniformly
|
||||||
@@ -559,9 +569,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--timestep_sampling",
|
"--timestep_sampling",
|
||||||
choices=["sigma", "uniform", "sigmoid"],
|
choices=["sigma", "uniform", "sigmoid", "shift"],
|
||||||
default="sigma",
|
default="sigma",
|
||||||
help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。",
|
help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal and shift of sigmoid."
|
||||||
|
" / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト。",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sigmoid_scale",
|
"--sigmoid_scale",
|
||||||
|
|||||||
Reference in New Issue
Block a user