rf from bluvoll's fork

This commit is contained in:
Darren Laurie
2026-02-01 23:48:59 +08:00
parent 132f0b6a15
commit 6380648f85
2 changed files with 215 additions and 14 deletions

View File

@@ -4284,6 +4284,87 @@ def add_dit_training_arguments(parser: argparse.ArgumentParser):
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
)
def add_flux2vae_training_arguments(parser: argparse.ArgumentParser):
    """Register Flux2-VAE command line options on *parser*.

    NOTE: this feature is deliberately disabled in this fork -- the
    function always raises NotImplementedError after registering the
    options, so the registrations only matter once the raise is removed.
    """
    option_specs = (
        (
            "--vae_reflection_padding",
            dict(
                action="store_true",
                help="switch VAE convolutions to reflection padding (improves border quality for some custom VAEs) / VAEの畳み込みを反射パディングに切り替える",
            ),
        ),
        (
            "--vae_custom_scale",
            dict(
                type=float,
                default=None,
                help="override the latent scaling factor applied after VAE encode / VAEエンコード後のスケーリング係数を上書きする",
            ),
        ),
        (
            "--vae_custom_shift",
            dict(
                type=float,
                default=None,
                help="apply a constant latent shift before scaling (e.g. Flux-style offset) / スケーリング前に潜在表現へ定数シフトを適用する",
            ),
        ),
    )
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
    # Dead code. Leave for future use. Hint: bluvoll.
    raise NotImplementedError("Flux2Vae is not supported in this fork.")
def add_fm_training_arguments(parser: argparse.ArgumentParser):
    """Register Rectified Flow / Flow Matching training options.

    Non SD3 / Flux Flow Matching -- the code base is different.
    Extended from add_sd_models_arguments.
    """
    add = parser.add_argument

    add(
        "--flow_model",
        action="store_true",
        help="enable Rectified Flow (including Flow Matching) training objective instead of standard diffusion / 通常の拡散ではなくFlow Matchingで学習する",
    )

    # TODO: Not implemented: --flow_use_ot as OT-CFM

    # How sigmas are drawn per batch element (logit-normal sampling by default).
    add(
        "--flow_timestep_distribution",
        type=str,
        default="logit_normal",
        choices=["logit_normal", "uniform"],
        help="sampling distribution over Flow Matching sigmas (default: logit_normal) / Flow Matchingのシグマの分布デフォルトlogit_normal",
    )
    add(
        "--flow_logit_mean",
        type=float,
        default=0.0,
        help="mean of the logit-normal distribution when using Flow Matching / Flow Matchingでlogit-normal分布を用いるときの平均値",
    )
    add(
        "--flow_logit_std",
        type=float,
        default=1.0,
        help="stddev of the logit-normal distribution when using Flow Matching / Flow Matchingでlogit-normal分布を用いるときの標準偏差",
    )

    # Resolution-dependent shifting of timestep schedules.
    # Disabled: --flow_uniform_base_pixels,
    add(
        "--flow_uniform_shift",
        action="store_true",
        help="apply resolution-dependent shift to Flow Matching timesteps (SD3-style) / Flow Matchingタイムステップに解像度依存のシフトを適用する",
    )
    add(
        "--flow_uniform_static_ratio",
        type=float,
        default=3.0,
        help="set sqrt(m/n) ratio for Resolution-dependent shifting of timestep schedules; set 1.0 to disable / 解像度に依存したタイムステップスケジュールのシフトのsqrt(m/n)比を設定します。無効にするには1.0を設定します。",
    )

    # CFM here means Contrastive Flow Matching (ΔFM), not conditional FM.
    add(
        "--contrastive_flow_matching",
        action="store_true",
        help="Enable Contrastive Flow Matching (ΔFM) objective. Works with v-parameterization or Flow Matching.",
    )
    add(
        "--cfm_lambda",
        type=float,
        default=0.05,
        help="Lambda weight for the contrastive term in ΔFM loss (default: 0.05).",
    )

    # TODO: Not implemented: --use_zero_cond_dropout
def get_sanitized_config_or_none(args: argparse.Namespace):
# if `--log_config` is enabled, return args for logging. if not, return None.
@@ -4442,6 +4523,57 @@ def verify_training_args(args: argparse.Namespace):
)
args.sample_every_n_steps = None
def verify_fm_training_args(args: argparse.Namespace):
    """Validate Flow Matching options; continues verify_training_args.

    No-op unless ``--flow_model`` is set. Mutates *args* in place to
    neutralize diffusion-only loss options that have no meaning under
    Flow Matching, and raises ValueError / NotImplementedError for
    invalid or unsupported combinations.
    """
    if not args.flow_model:
        return
    logger.info("Using Flow Matching training objective.")

    if args.v_parameterization:
        raise ValueError("`--flow_model` is incompatible with `--v_parameterization`; Flow Matching already predicts velocity.")

    # Diffusion-specific loss tweaks: warn and disable instead of failing.
    if args.min_snr_gamma:
        logger.warning("`--min_snr_gamma` is ignored when Flow Matching is enabled.")
        args.min_snr_gamma = None
    if args.debiased_estimation_loss:
        logger.warning("`--debiased_estimation_loss` is ignored when Flow Matching is enabled.")
        args.debiased_estimation_loss = False
    if args.scale_v_pred_loss_like_noise_pred:
        logger.warning("`--scale_v_pred_loss_like_noise_pred` is ignored when Flow Matching is enabled.")
        args.scale_v_pred_loss_like_noise_pred = False
    if args.v_pred_like_loss:
        logger.warning("`--v_pred_like_loss` is ignored when Flow Matching is enabled.")
        args.v_pred_like_loss = None

    # BUGFIX: --flow_use_ot is not registered by add_fm_training_arguments
    # (see the TODO there), so a direct attribute access raised
    # AttributeError whenever flow matching was enabled. Read it
    # defensively; the behavior is unchanged if the flag is ever added.
    if getattr(args, "flow_use_ot", False):
        logger.info("Using cosine optimal transport pairing for Flow Matching batches.")
        raise NotImplementedError("`--flow_use_ot` is not available in this fork.")

    # NOTE(review): static_ratio defaults to 3.0 (never None from argparse),
    # so shift_enabled is effectively always True here — confirm intent.
    shift_enabled = args.flow_uniform_shift or args.flow_uniform_static_ratio is not None

    distribution = args.flow_timestep_distribution
    if distribution == "logit_normal":
        if not args.flow_logit_std > 0.0:
            raise ValueError("`--flow_logit_std` must be positive.")
        if args.flow_logit_mean is None:
            raise ValueError("`--flow_logit_mean` must be present.")
        logger.info(
            "Flow Matching timesteps sampled from logit-normal distribution with "
            f"mean={args.flow_logit_mean}, std={args.flow_logit_std}."
        )
    elif distribution == "uniform":
        logger.info("Flow Matching timesteps sampled uniformly in [0, 1].")
    else:
        raise ValueError(f"Unknown Flow Matching timestep distribution: {distribution}")

    if shift_enabled:
        if args.flow_uniform_static_ratio is not None:
            if not args.flow_uniform_static_ratio > 0.0:
                raise ValueError("`--flow_uniform_static_ratio` must be positive.")
            logger.info(f"Flow Matching timestep shift uses static ratio={args.flow_uniform_static_ratio}.")
        else:
            raise NotImplementedError("`--flow_uniform_base_pixels` is not available in this fork. Set --flow_uniform_static_ratio=1.0 instead.")

    if args.contrastive_flow_matching:
        if args.cfm_lambda is None:
            raise ValueError("`--cfm_lambda` must be present.")
        logger.info(f"Contrastive Flow Matching is enabled with cfm_lambda={args.cfm_lambda}.")
def add_dataset_arguments(
parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool
@@ -6085,7 +6217,10 @@ def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: tor
timesteps = timesteps.long().to(device)
return timesteps
# 260125: It is worth adding dedicated extended functions for flow matching, which requires a custom timestep mechanism.
# 260125: They should also combine the idea from get_noisy_model_input_and_timesteps.
# 260125: For code pattern, assume args are already valid; do not set default values. Raise exceptions instead.
# 260125: The original code also looked up huber_c, but I have already separated that into a later stage.
def get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents: torch.FloatTensor
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
@@ -6107,22 +6242,64 @@ def get_noise_noisy_latents_and_timesteps(
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
flow_model_enabled = args.flow_model
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
if args.ip_noise_gamma_random_strength:
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
if flow_model_enabled:
# Flow Matching. Timestep is calculated here, through sigma.
# Reason for max_timestep -= 1 is not known yet.
max_timestep = max_timestep - 1
# distribution refers to SD3's weighting_scheme such as logit_normal and cosmap.
distribution = args.flow_timestep_distribution
if distribution == "logit_normal":
logits = torch.normal(
mean=args.flow_logit_mean,
std=args.flow_logit_std,
size=(b_size,),
device=latents.device,
)
sigmas = torch.sigmoid(logits)
elif distribution == "uniform":
sigmas = torch.rand((b_size,), device=latents.device)
else:
strength = args.ip_noise_gamma
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
raise ValueError(f"Unknown flow_timestep_distribution: {distribution}")
# Resolution-dependent shifting of timestep schedules. No need to check independently. If we really calaulates m/n under ARB, it will be always close to 1.0.
shift_requested = args.flow_uniform_shift
static_ratio = args.flow_uniform_static_ratio
if sigmas is None:
raise ValueError("FM: sigmas is None. Should not happens.")
if shift_requested:
if not static_ratio > 0.0:
raise ValueError("Invalid ratio. Must be a postitive number.")
ratios = torch.full((b_size,), float(static_ratio), device=latents.device, dtype=torch.float32)
t_ref = sigmas
sigmas = ratios * t_ref / (1 + (ratios - 1) * t_ref)
timesteps = torch.clamp((sigmas * max_timestep).long(), 0, max_timestep)
# Forward for Flow Matching. TODO: args.flow_use_ot
sigmas_view = sigmas.view(-1, 1, 1, 1)
# Note: add_noise has been done as sigma
noisy_latents = sigmas_view * noise + (1.0 - sigmas_view) * latents
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# It is called Score Matching. Codes are original.
timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
if args.ip_noise_gamma_random_strength:
strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma
else:
strength = args.ip_noise_gamma
noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps)
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
return noise, noisy_latents, timesteps
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
return None

View File

@@ -233,6 +233,7 @@ class NativeTrainer:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
# Note: Flow Matching will bend the timesteps
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
# ensure the hidden state will require grad
@@ -255,10 +256,16 @@ class NativeTrainer:
weight_dtype,
)
if args.v_parameterization:
if args.flow_model:
# Rectified Flow. Kind of vpred. Math is fun.
target = noise - latents
elif args.v_parameterization:
# v-parameterization training
# velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
# target = (alphas_cumprod[timesteps] ** 0.5) * noise - (1 - alphas_cumprod[timesteps]) ** 0.5 * latents
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
# EPS mode
target = noise
# differential output preservation
@@ -283,7 +290,10 @@ class NativeTrainer:
)
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
return noise_pred, target, timesteps, None
# weighting is unused unless cosmap is used (See SD3 / Flux).
weighting = None
# noise is used for Contrastive Flow Matching.
return noise_pred, target, timesteps, weighting, noise
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
if args.min_snr_gamma:
@@ -439,7 +449,7 @@ class NativeTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
noise_pred, target, timesteps, weighting, noise = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -456,6 +466,14 @@ class NativeTrainer:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.flow_model and args.contrastive_flow_matching and latents.size(0) > 1:
# Original code accepts vpred, which is strange.
negative_latents = latents.roll(1, 0)
negative_noise = noise.roll(1, 0)
#with torch.no_grad():
target_negative = negative_noise - negative_latents
loss_contrastive = torch.nn.functional.mse_loss(noise_pred.float(), target_negative.float(), reduction="none")
loss = loss - args.cfm_lambda * loss_contrastive
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -465,6 +483,10 @@ class NativeTrainer:
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
# From Flow Matching code. So strange.
# if loss.ndim != 0:
# loss = loss.mean()
return loss.mean()
def train(self, args):
@@ -472,6 +494,7 @@ class NativeTrainer:
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args)
train_util.verify_fm_training_args(args)
train_util.prepare_dataset_args(args, True)
if args.skip_cache_check:
train_util.set_skip_npz_path_check(True)
@@ -1686,6 +1709,7 @@ def setup_parser() -> argparse.ArgumentParser:
add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
train_util.add_fm_training_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)