mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-15 08:36:41 +00:00)
rf from bluvoll's fork
@@ -4284,6 +4284,87 @@ def add_dit_training_arguments(parser: argparse.ArgumentParser):
        "この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
    )


def add_flux2vae_training_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--vae_reflection_padding",
        action="store_true",
        help="switch VAE convolutions to reflection padding (improves border quality for some custom VAEs) / VAEの畳み込みを反射パディングに切り替える",
    )
    parser.add_argument(
        "--vae_custom_scale",
        type=float,
        default=None,
        help="override the latent scaling factor applied after VAE encode / VAEエンコード後のスケーリング係数を上書きする",
    )
    parser.add_argument(
        "--vae_custom_shift",
        type=float,
        default=None,
        help="apply a constant latent shift before scaling (e.g. Flux-style offset) / スケーリング前に潜在表現へ定数シフトを適用する",
    )

    # Dead code: this function is not wired up anywhere yet. Left for future use. Hint: bluvoll.
    raise NotImplementedError("Flux2Vae is not supported in this fork.")

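# A sketch of the encode-time semantics the two overrides above imply (assumed wiring;
# the VAE hook itself is not part of this commit, and since the function above raises
# NotImplementedError, these flags are not active in this fork):
#
#   latents = vae.encode(pixels).latent_dist.sample()
#   if args.vae_custom_shift is not None:
#       latents = latents - args.vae_custom_shift  # constant offset, applied before scaling
#   if args.vae_custom_scale is not None:
#       latents = latents * args.vae_custom_scale  # replaces the model's scaling_factor
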
def add_fm_training_arguments(parser: argparse.ArgumentParser):
    # Flow Matching for models other than SD3 / Flux (their code base is different).
    # Extends add_sd_models_arguments.
    parser.add_argument(
        "--flow_model",
        action="store_true",
        help="enable Rectified Flow (including Flow Matching) training objective instead of standard diffusion / 通常の拡散ではなくFlow Matchingで学習する",
    )

    # TODO: Not implemented: --flow_use_ot as OT-CFM

    # Logit-Normal Sampling
    parser.add_argument(
        "--flow_timestep_distribution",
        type=str,
        default="logit_normal",
        choices=["logit_normal", "uniform"],
        help="sampling distribution over Flow Matching sigmas (default: logit_normal) / Flow Matchingのシグマの分布(デフォルトlogit_normal)",
    )
    parser.add_argument(
        "--flow_logit_mean",
        type=float,
        default=0.0,
        help="mean of the logit-normal distribution when using Flow Matching / Flow Matchingでlogit-normal分布を用いるときの平均値",
    )
    parser.add_argument(
        "--flow_logit_std",
        type=float,
        default=1.0,
        help="stddev of the logit-normal distribution when using Flow Matching / Flow Matchingでlogit-normal分布を用いるときの標準偏差",
    )
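
    # How the sampler consumes these (see get_noise_noisy_latents_and_timesteps below):
    # sigma = sigmoid(N(flow_logit_mean, flow_logit_std)). With the defaults this concentrates
    # sigmas around 0.5; raising the mean biases sampling toward noisier timesteps (sigma -> 1).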

    # Resolution-dependent shifting of timestep schedules
    # Disabled: --flow_uniform_base_pixels,
    parser.add_argument(
        "--flow_uniform_shift",
        action="store_true",
        help="apply resolution-dependent shift to Flow Matching timesteps (SD3-style) / Flow Matchingタイムステップに解像度依存のシフトを適用する",
    )
    parser.add_argument(
        "--flow_uniform_static_ratio",
        type=float,
        default=3.0,
        help="set sqrt(m/n) ratio for resolution-dependent shifting of timestep schedules; set 1.0 to disable / 解像度に依存したタイムステップスケジュールのシフトのsqrt(m/n)比を設定します。無効にするには1.0を設定します。",
    )
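
    # Worked example for the shift controlled by the two flags above: a sampled sigma t is
    # remapped to t' = r * t / (1 + (r - 1) * t). With the default r = 3.0, t = 0.5 becomes
    # 3 * 0.5 / (1 + 2 * 0.5) = 0.75, biasing training toward noisier timesteps; r = 1.0 is
    # the identity mapping, which is why 1.0 disables the shift.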

    # CFM here means Contrastive Flow Matching, not Conditional Flow Matching.
    parser.add_argument(
        "--contrastive_flow_matching",
        action="store_true",
        help="Enable Contrastive Flow Matching (ΔFM) objective. Works with v-parameterization or Flow Matching.",
    )
    parser.add_argument(
        "--cfm_lambda",
        type=float,
        default=0.05,
        help="Lambda weight for the contrastive term in ΔFM loss (default: 0.05).",
    )

    # TODO: Not implemented: --use_zero_cond_dropout


def get_sanitized_config_or_none(args: argparse.Namespace):
    # if `--log_config` is enabled, return args for logging. if not, return None.
@@ -4442,6 +4523,57 @@ def verify_training_args(args: argparse.Namespace):
    )
    args.sample_every_n_steps = None


def verify_fm_training_args(args: argparse.Namespace):
    # Continued from verify_training_args, but specific to flow matching.
    if not args.flow_model:
        return
    logger.info("Using Flow Matching training objective.")
    if args.v_parameterization:
        raise ValueError("`--flow_model` is incompatible with `--v_parameterization`; Flow Matching already predicts velocity.")
    if args.min_snr_gamma:
        logger.warning("`--min_snr_gamma` is ignored when Flow Matching is enabled.")
        args.min_snr_gamma = None
    if args.debiased_estimation_loss:
        logger.warning("`--debiased_estimation_loss` is ignored when Flow Matching is enabled.")
        args.debiased_estimation_loss = False
    if args.scale_v_pred_loss_like_noise_pred:
        logger.warning("`--scale_v_pred_loss_like_noise_pred` is ignored when Flow Matching is enabled.")
        args.scale_v_pred_loss_like_noise_pred = False
    if args.v_pred_like_loss:
        logger.warning("`--v_pred_like_loss` is ignored when Flow Matching is enabled.")
        args.v_pred_like_loss = None
    # --flow_use_ot is not registered by the parser yet (see the TODO above), so use
    # getattr to avoid an AttributeError on the Namespace.
    if getattr(args, "flow_use_ot", False):
        logger.info("Using cosine optimal transport pairing for Flow Matching batches.")
        raise NotImplementedError("`--flow_use_ot` is not available in this fork.")

    distribution = args.flow_timestep_distribution
    if distribution == "logit_normal":
        if not args.flow_logit_std > 0.0:
            raise ValueError("`--flow_logit_std` must be positive.")
        if args.flow_logit_mean is None:
            raise ValueError("`--flow_logit_mean` must be present.")
        logger.info(
            "Flow Matching timesteps sampled from logit-normal distribution with "
            f"mean={args.flow_logit_mean}, std={args.flow_logit_std}."
        )
    elif distribution == "uniform":
        logger.info("Flow Matching timesteps sampled uniformly in [0, 1].")
    else:
        raise ValueError(f"Unknown Flow Matching timestep distribution: {distribution}")

    # Only validate the shift when --flow_uniform_shift is set; checking whether the static
    # ratio is non-None would always trigger, since it has a non-None default (3.0).
    if args.flow_uniform_shift:
        if args.flow_uniform_static_ratio is not None:
            if not args.flow_uniform_static_ratio > 0.0:
                raise ValueError("`--flow_uniform_static_ratio` must be positive.")
            logger.info(f"Flow Matching timestep shift uses static ratio={args.flow_uniform_static_ratio}.")
        else:
            raise NotImplementedError("`--flow_uniform_base_pixels` is not available in this fork. Set --flow_uniform_static_ratio=1.0 instead.")

    if args.contrastive_flow_matching:
        if args.cfm_lambda is None:
            raise ValueError("`--cfm_lambda` must be present.")
        logger.info(f"Contrastive Flow Matching is enabled with cfm_lambda={args.cfm_lambda}.")
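
# Example invocation (a sketch; "train_native.py" is a placeholder for whichever script
# wires up NativeTrainer in this fork, and model/dataset flags are omitted):
#
#   python train_native.py --flow_model \
#       --flow_timestep_distribution logit_normal --flow_logit_mean 0.0 --flow_logit_std 1.0 \
#       --flow_uniform_shift --flow_uniform_static_ratio 3.0 \
#       --contrastive_flow_matching --cfm_lambda 0.05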


def add_dataset_arguments(
    parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
@@ -6085,7 +6217,10 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
    timesteps = timesteps.long().to(device)
    return timesteps


# 260125: It is worth having an extended function dedicated to flow matching, which requires a custom timestep mechanism.
# 260125: It should also combine the idea from get_noisy_model_input_and_timesteps.
# 260125: Code pattern: assume args are already validated; do not set defaults here, raise exceptions instead.
# 260125: The original code also looked up huber_c, but I have separated that into a later stage.
def get_noise_noisy_latents_and_timesteps(
    args, noise_scheduler, latents: torch.FloatTensor
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
@@ -6107,22 +6242,64 @@ def get_noise_noisy_latents_and_timesteps(
    min_timestep = 0 if args.min_timestep is None else args.min_timestep
    max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

    flow_model_enabled = args.flow_model

    if flow_model_enabled:
        # Flow Matching. The timestep is derived here, through sigma.
        # Reason for max_timestep -= 1 is not known yet.
        max_timestep = max_timestep - 1
        # distribution corresponds to SD3's weighting_scheme, e.g. logit_normal and cosmap.
        distribution = args.flow_timestep_distribution
        if distribution == "logit_normal":
            logits = torch.normal(
                mean=args.flow_logit_mean,
                std=args.flow_logit_std,
                size=(b_size,),
                device=latents.device,
            )
            sigmas = torch.sigmoid(logits)
        elif distribution == "uniform":
            sigmas = torch.rand((b_size,), device=latents.device)
        else:
            raise ValueError(f"Unknown flow_timestep_distribution: {distribution}")

        # Resolution-dependent shifting of timestep schedules. No separate resolution check
        # is needed: if we actually computed m/n under aspect-ratio bucketing, it would
        # always be close to 1.0.
        shift_requested = args.flow_uniform_shift
        static_ratio = args.flow_uniform_static_ratio

        if sigmas is None:
            raise ValueError("FM: sigmas is None. This should not happen.")

        if shift_requested:
            if not static_ratio > 0.0:
                raise ValueError("Invalid ratio. Must be a positive number.")
            ratios = torch.full((b_size,), float(static_ratio), device=latents.device, dtype=torch.float32)
            t_ref = sigmas
            sigmas = ratios * t_ref / (1 + (ratios - 1) * t_ref)

        timesteps = torch.clamp((sigmas * max_timestep).long(), 0, max_timestep)

        # Forward process for Flow Matching. TODO: args.flow_use_ot
        sigmas_view = sigmas.view(-1, 1, 1, 1)
        # Note: what add_noise does elsewhere is done here via sigma interpolation
        noisy_latents = sigmas_view * noise + (1.0 - sigmas_view) * latents
    else:
        # This is the Score Matching path; the code is unchanged from upstream.
        timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        if args.ip_noise_gamma:
            if args.ip_noise_gamma_random_strength:
                strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
            else:
                strength = args.ip_noise_gamma
            noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
        else:
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    return noise, noisy_latents, timesteps
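

# Consistency check for the Flow Matching forward above (a sketch for illustration, not
# called anywhere): with x_t = sigma * noise + (1 - sigma) * latents, the velocity
# dx_t/dsigma is noise - latents, which is exactly the Rectified Flow target used by
# NativeTrainer further down.
#
#   latents, noise = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
#   sigma = torch.rand(2).view(-1, 1, 1, 1)
#   eps = 0.1
#   x_a = sigma * noise + (1.0 - sigma) * latents
#   x_b = (sigma + eps) * noise + (1.0 - sigma - eps) * latents
#   assert torch.allclose((x_b - x_a) / eps, noise - latents, atol=1e-4)
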
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
    if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
        return None
@@ -233,6 +233,7 @@ class NativeTrainer:

            # Sample noise, sample a random timestep for each image, and add noise to the latents,
            # with noise offset and/or multires noise if specified
            # Note: Flow Matching computes its own ("bent") timesteps inside this call
            noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)

            # ensure the hidden state will require grad
@@ -255,10 +256,16 @@ class NativeTrainer:
                weight_dtype,
            )

            if args.flow_model:
                # Rectified Flow. Kind of vpred. Math is fun.
                target = noise - latents
            elif args.v_parameterization:
                # v-parameterization training
                # velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
                # target = (alphas_cumprod[timesteps] ** 0.5) * noise - (1 - alphas_cumprod[timesteps]) ** 0.5 * latents
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                # EPS mode
                target = noise

            # differential output preservation
@@ -283,7 +290,10 @@ class NativeTrainer:
            )
            target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)

        # weighting is unused unless cosmap is used (see SD3 / Flux).
        weighting = None
        # noise is returned for Contrastive Flow Matching.
        return noise_pred, target, timesteps, weighting, noise

    def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
        if args.min_snr_gamma:
@@ -439,7 +449,7 @@ class NativeTrainer:
                        text_encoder_conds[i] = encoded_text_encoder_conds[i]

            # sample noise, call unet, get target
            noise_pred, target, timesteps, weighting, noise = self.get_noise_pred_and_target(
                args,
                accelerator,
                noise_scheduler,
@@ -456,6 +466,14 @@ class NativeTrainer:
            loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
            if weighting is not None:
                loss = loss * weighting
            if args.flow_model and args.contrastive_flow_matching and latents.size(0) > 1:
                # The reference implementation also accepts v-pred here, which is strange.
                negative_latents = latents.roll(1, 0)
                negative_noise = noise.roll(1, 0)
                # with torch.no_grad():
                target_negative = negative_noise - negative_latents
                loss_contrastive = torch.nn.functional.mse_loss(noise_pred.float(), target_negative.float(), reduction="none")
                loss = loss - args.cfm_lambda * loss_contrastive
            if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                loss = apply_masked_loss(loss, batch)
            loss = loss.mean([1, 2, 3])
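
            # The ΔFM update above, in equation form (lambda = args.cfm_lambda):
            #   L = ||v_pred - (noise - latents)||^2 - lambda * ||v_pred - (noise' - latents')||^2
            # where (noise', latents') is the batch rolled by one position, i.e. mismatched
            # negative pairs. The subtracted term pushes the prediction away from other samples'
            # velocities, which is why the branch requires a batch size greater than 1.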
@@ -465,6 +483,10 @@ class NativeTrainer:

            loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)

            # From the Flow Matching reference code; redundant here, since we reduce below.
            # if loss.ndim != 0:
            #     loss = loss.mean()

            return loss.mean()

    def train(self, args):
@@ -472,6 +494,7 @@ class NativeTrainer:
        session_id = random.randint(0, 2**32)
        training_started_at = time.time()
        train_util.verify_training_args(args)
        train_util.verify_fm_training_args(args)
        train_util.prepare_dataset_args(args, True)
        if args.skip_cache_check:
            train_util.set_skip_npz_path_check(True)
@@ -1686,6 +1709,7 @@ def setup_parser() -> argparse.ArgumentParser:

    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)
    train_util.add_fm_training_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, True)
    train_util.add_masked_loss_arguments(parser)