mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
add fine tuning FLUX.1 (WIP)
This commit is contained in:
@@ -274,85 +274,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
):
|
||||
# copy from sd3_train.py and modified
|
||||
|
||||
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
|
||||
sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
|
||||
schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device)
|
||||
timesteps = timesteps.to(accelerator.device)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
def compute_density_for_timestep_sampling(
|
||||
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
||||
):
|
||||
"""Compute the density for sampling the timesteps when doing SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
||||
u = torch.nn.functional.sigmoid(u)
|
||||
elif weighting_scheme == "mode":
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
else:
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
return u
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random t-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=accelerator.device))
|
||||
else:
|
||||
t = torch.rand((bsz,), device=accelerator.device)
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
u = compute_density_for_timestep_sampling(
|
||||
weighting_scheme=args.weighting_scheme,
|
||||
batch_size=bsz,
|
||||
logit_mean=args.logit_mean,
|
||||
logit_std=args.logit_std,
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
|
||||
timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
# pack latents and get img_ids
|
||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||
@@ -425,20 +354,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
|
||||
if args.model_prediction_type == "raw":
|
||||
# use model_pred as is
|
||||
weighting = None
|
||||
elif args.model_prediction_type == "additive":
|
||||
# add the model_pred to the noisy_model_input
|
||||
model_pred = model_pred + noisy_model_input
|
||||
weighting = None
|
||||
elif args.model_prediction_type == "sigma_scaled":
|
||||
# apply sigma scaling
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
|
||||
# these weighting schemes use a uniform timestep sampling
|
||||
# and instead post-weight the loss
|
||||
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
# flow matching loss: this is different from SD3
|
||||
target = noise - latents
|
||||
@@ -469,83 +386,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = train_network.setup_parser()
|
||||
# sdxl_train_util.add_sdxl_training_arguments(parser)
|
||||
parser.add_argument("--clip_l", type=str, help="path to clip_l")
|
||||
parser.add_argument("--t5xxl", type=str, help="path to t5xxl")
|
||||
parser.add_argument("--ae", type=str, help="path to ae")
|
||||
parser.add_argument("--apply_t5_attn_mask", action="store_true")
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs_to_disk",
|
||||
action="store_true",
|
||||
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
||||
)
|
||||
flux_train_utils.add_flux_train_arguments(parser)
|
||||
|
||||
parser.add_argument(
|
||||
"--split_mode",
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
||||
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_max_token_length",
|
||||
type=int,
|
||||
default=None,
|
||||
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
|
||||
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
|
||||
)
|
||||
# copy from Diffusers
|
||||
parser.add_argument(
|
||||
"--weighting_scheme",
|
||||
type=str,
|
||||
default="none",
|
||||
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
||||
)
|
||||
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
|
||||
parser.add_argument(
|
||||
"--mode_scale",
|
||||
type=float,
|
||||
default=1.29,
|
||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=3.5,
|
||||
help="the FLUX.1 dev variant is a guidance distilled model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--timestep_sampling",
|
||||
choices=["sigma", "uniform", "sigmoid"],
|
||||
default="sigma",
|
||||
help="Method to sample timesteps: sigma-based, uniform random, or sigmoid of random normal. / タイムステップをサンプリングする方法:sigma、random uniform、またはrandom normalのsigmoid。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sigmoid_scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_prediction_type",
|
||||
choices=["raw", "additive", "sigma_scaled"],
|
||||
default="sigma_scaled",
|
||||
help="How to interpret and process the model prediction: "
|
||||
"raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
|
||||
" / モデル予測の解釈と処理方法:"
|
||||
"raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--discrete_flow_shift",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user