def add_flux2vae_training_arguments(parser: argparse.ArgumentParser):
    """Register Flux2-VAE override options on *parser*.

    NOTE: intentionally unusable in this fork — the options are registered and
    then NotImplementedError is raised unconditionally. Kept for future use.
    """
    parser.add_argument(
        "--vae_reflection_padding",
        action="store_true",
        help="switch VAE convolutions to reflection padding (improves border quality for some custom VAEs) / VAEの畳み込みを反射パディングに切り替える",
    )
    parser.add_argument(
        "--vae_custom_scale",
        type=float,
        default=None,
        help="override the latent scaling factor applied after VAE encode / VAEエンコード後のスケーリング係数を上書きする",
    )
    parser.add_argument(
        "--vae_custom_shift",
        type=float,
        default=None,
        help="apply a constant latent shift before scaling (e.g. Flux-style offset) / スケーリング前に潜在表現へ定数シフトを適用する",
    )

    # Dead code. Leave for future use. Hint: bluvoll.
    raise NotImplementedError("Flux2Vae is not supported in this fork.")


def add_fm_training_arguments(parser: argparse.ArgumentParser):
    """Register Flow Matching (Rectified Flow) training options on *parser*.

    Non SD3 / Flux Flow Matching: the code base differs, so this extends
    add_sd_models_arguments for the generic trainers.
    """
    parser.add_argument(
        "--flow_model",
        action="store_true",
        help="enable Rectified Flow (including Flow Matching) training objective instead of standard diffusion / 通常の拡散ではなくFlow Matchingで学習する",
    )

    # TODO: Not implemented: --flow_use_ot as OT-CFM

    # Logit-normal sampling of sigmas (SD3-style weighting scheme).
    parser.add_argument(
        "--flow_timestep_distribution",
        type=str,
        default="logit_normal",
        choices=["logit_normal", "uniform"],
        help="sampling distribution over Flow Matching sigmas (default: logit_normal) / Flow Matchingのシグマの分布(デフォルトlogit_normal)",
    )
    parser.add_argument(
        "--flow_logit_mean",
        type=float,
        default=0.0,
        help="mean of the logit-normal distribution when using Flow Matching / Flow Matchingでlogit-normal分布を用いるときの平均値",
    )
    parser.add_argument(
        "--flow_logit_std",
        type=float,
        default=1.0,
        help="stddev of the logit-normal distribution when using Flow Matching / Flow Matchingでlogit-normal分布を用いるときの標準偏差",
    )

    # Resolution-dependent shifting of timestep schedules.
    # Disabled: --flow_uniform_base_pixels
    parser.add_argument(
        "--flow_uniform_shift",
        action="store_true",
        help="apply resolution-dependent shift to Flow Matching timesteps (SD3-style) / Flow Matchingタイムステップに解像度依存のシフトを適用する",
    )
    parser.add_argument(
        "--flow_uniform_static_ratio",
        type=float,
        default=3.0,
        help="set sqrt(m/n) ratio for Resolution-dependent shifting of timestep schedules; set 1.0 to disable / 解像度に依存したタイムステップスケジュールのシフトのsqrt(m/n)比を設定します。無効にするには1.0を設定します。",
    )

    # CFM, but Contrastive Flow Matching.
    parser.add_argument(
        "--contrastive_flow_matching",
        action="store_true",
        help="Enable Contrastive Flow Matching (ΔFM) objective. Works with v-parameterization or Flow Matching.",
    )
    parser.add_argument(
        "--cfm_lambda",
        type=float,
        default=0.05,
        help="Lambda weight for the contrastive term in ΔFM loss (default: 0.05).",
    )

    # TODO: Not implemented: --use_zero_cond_dropout


def verify_fm_training_args(args: argparse.Namespace):
    """Validate Flow Matching options (continuation of verify_training_args).

    No-op unless --flow_model is set. Diffusion-only options that conflict
    with Flow Matching are warned about and disabled in place; invalid values
    raise ValueError.
    """
    if not args.flow_model:
        return
    logger.info("Using Flow Matching training objective.")
    if args.v_parameterization:
        raise ValueError("`--flow_model` is incompatible with `--v_parameterization`; Flow Matching already predicts velocity.")
    if args.min_snr_gamma:
        logger.warning("`--min_snr_gamma` is ignored when Flow Matching is enabled.")
        args.min_snr_gamma = None
    if args.debiased_estimation_loss:
        logger.warning("`--debiased_estimation_loss` is ignored when Flow Matching is enabled.")
        args.debiased_estimation_loss = False
    if args.scale_v_pred_loss_like_noise_pred:
        logger.warning("`--scale_v_pred_loss_like_noise_pred` is ignored when Flow Matching is enabled.")
        args.scale_v_pred_loss_like_noise_pred = False
    if args.v_pred_like_loss:
        logger.warning("`--v_pred_like_loss` is ignored when Flow Matching is enabled.")
        args.v_pred_like_loss = None

    # --flow_use_ot is not registered by add_fm_training_arguments (see the TODO
    # there), so plain attribute access would raise AttributeError on every
    # Flow Matching run. Read it defensively until the option exists.
    if getattr(args, "flow_use_ot", False):
        logger.info("Using cosine optimal transport pairing for Flow Matching batches.")
        raise NotImplementedError("`--flow_use_ot` is not available in this fork.")

    distribution = args.flow_timestep_distribution
    if distribution == "logit_normal":
        if args.flow_logit_mean is None:
            raise ValueError("`--flow_logit_mean` must be present.")
        if not args.flow_logit_std > 0.0:
            raise ValueError("`--flow_logit_std` must be positive.")
        logger.info(
            "Flow Matching timesteps sampled from logit-normal distribution with "
            f"mean={args.flow_logit_mean}, std={args.flow_logit_std}."
        )
    elif distribution == "uniform":
        logger.info("Flow Matching timesteps sampled uniformly in [0, 1].")
    else:
        raise ValueError(f"Unknown Flow Matching timestep distribution: {distribution}")

    # Resolution-dependent shifting. --flow_uniform_static_ratio defaults to
    # 3.0 and argparse never yields None for it, so the ratio is validated
    # unconditionally; None would mean the (unsupported) dynamic mode.
    if args.flow_uniform_static_ratio is None:
        raise NotImplementedError("`--flow_uniform_base_pixels` is not available in this fork. Set --flow_uniform_static_ratio=1.0 instead.")
    if not args.flow_uniform_static_ratio > 0.0:
        raise ValueError("`--flow_uniform_static_ratio` must be positive.")
    if args.flow_uniform_shift:
        # Log only when shifting was actually requested; the original logged
        # the ratio unconditionally because its `shift_enabled` test was
        # always true (the ratio default is never None).
        logger.info(f"Flow Matching timestep shift uses static ratio={args.flow_uniform_static_ratio}.")

    if args.contrastive_flow_matching:
        if args.cfm_lambda is None:
            raise ValueError("`--cfm_lambda` must be present.")
        logger.info(f"Contrastive Flow Matching is enabled with cfm_lambda={args.cfm_lambda}.")


# NOTE (260125): it is worth adding extended helpers dedicated to flow matching,
# which require a custom timestep mechanism; they should also fold in the ideas
# from get_noisy_model_input_and_timesteps.
# NOTE (260125): code pattern: assume args are already validated — raise
# exceptions instead of substituting default values here.
# NOTE (260125): the original code also resolved huber_c here, but that has
# already been separated into a later stage.
def get_noise_noisy_latents_and_timesteps( args, noise_scheduler, latents: torch.FloatTensor ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: @@ -6107,22 +6242,64 @@ def get_noise_noisy_latents_and_timesteps( min_timestep = 0 if args.min_timestep is None else args.min_timestep max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep - timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) + flow_model_enabled = args.flow_model - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - if args.ip_noise_gamma_random_strength: - strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + if flow_model_enabled: + # Flow Matching. Timestep is calculated here, through sigma. + # Reason for max_timestep -= 1 is not known yet. + max_timestep = max_timestep - 1 + # distribution refers to SD3's weighting_scheme such as logit_normal and cosmap. + distribution = args.flow_timestep_distribution + if distribution == "logit_normal": + logits = torch.normal( + mean=args.flow_logit_mean, + std=args.flow_logit_std, + size=(b_size,), + device=latents.device, + ) + sigmas = torch.sigmoid(logits) + elif distribution == "uniform": + sigmas = torch.rand((b_size,), device=latents.device) else: - strength = args.ip_noise_gamma - noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + raise ValueError(f"Unknown flow_timestep_distribution: {distribution}") + + # Resolution-dependent shifting of timestep schedules. No need to check independently. If we really calaulates m/n under ARB, it will be always close to 1.0. + shift_requested = args.flow_uniform_shift + static_ratio = args.flow_uniform_static_ratio + + if sigmas is None: + raise ValueError("FM: sigmas is None. 
Should not happens.") + + if shift_requested: + if not static_ratio > 0.0: + raise ValueError("Invalid ratio. Must be a postitive number.") + ratios = torch.full((b_size,), float(static_ratio), device=latents.device, dtype=torch.float32) + t_ref = sigmas + sigmas = ratios * t_ref / (1 + (ratios - 1) * t_ref) + + timesteps = torch.clamp((sigmas * max_timestep).long(), 0, max_timestep) + + # Forward for Flow Matching. TODO: args.flow_use_ot + sigmas_view = sigmas.view(-1, 1, 1, 1) + # Note: add_noise has been done as sigma + noisy_latents = sigmas_view * noise + (1.0 - sigmas_view) * latents else: - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # It is called Score Matching. Codes are original. + timesteps = get_timesteps(min_timestep, max_timestep, b_size, latents.device) + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + if args.ip_noise_gamma: + if args.ip_noise_gamma_random_strength: + strength = torch.rand(1, device=latents.device) * args.ip_noise_gamma + else: + strength = args.ip_noise_gamma + noisy_latents = noise_scheduler.add_noise(latents, noise + strength * torch.randn_like(latents), timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) return noise, noisy_latents, timesteps - def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): return None diff --git a/train_native.py b/train_native.py index 8bb0765c..0efc912e 100644 --- a/train_native.py +++ b/train_native.py @@ -233,6 +233,7 @@ class NativeTrainer: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified + # Note: Flow Matching will bend the timesteps noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, 
latents) # ensure the hidden state will require grad @@ -255,10 +256,16 @@ class NativeTrainer: weight_dtype, ) - if args.v_parameterization: + if args.flow_model: + # Rectified Flow. Kind of vpred. Math is fun. + target = noise - latents + elif args.v_parameterization: # v-parameterization training + # velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + # target = (alphas_cumprod[timesteps] ** 0.5) * noise - (1 - alphas_cumprod[timesteps]) ** 0.5 * latents target = noise_scheduler.get_velocity(latents, noise, timesteps) else: + # EPS mode target = noise # differential output preservation @@ -283,7 +290,10 @@ class NativeTrainer: ) target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - return noise_pred, target, timesteps, None + # weighting is unused unless cosmap is used (See SD3 / Flux). + weighting = None + # noise is used for Contrastive Flow Matching. + return noise_pred, target, timesteps, weighting, noise def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: @@ -439,7 +449,7 @@ class NativeTrainer: text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + noise_pred, target, timesteps, weighting, noise = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -456,6 +466,14 @@ class NativeTrainer: loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting + if args.flow_model and args.contrastive_flow_matching and latents.size(0) > 1: + # Original code accepts vpred, which is strange. 
+ negative_latents = latents.roll(1, 0) + negative_noise = noise.roll(1, 0) + #with torch.no_grad(): + target_negative = negative_noise - negative_latents + loss_contrastive = torch.nn.functional.mse_loss(noise_pred.float(), target_negative.float(), reduction="none") + loss = loss - args.cfm_lambda * loss_contrastive if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -465,6 +483,10 @@ class NativeTrainer: loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + # From Flow Matching code. So strange. + # if loss.ndim != 0: + # loss = loss.mean() + return loss.mean() def train(self, args): @@ -472,6 +494,7 @@ class NativeTrainer: session_id = random.randint(0, 2**32) training_started_at = time.time() train_util.verify_training_args(args) + train_util.verify_fm_training_args(args) train_util.prepare_dataset_args(args, True) if args.skip_cache_check: train_util.set_skip_npz_path_check(True) @@ -1686,6 +1709,7 @@ def setup_parser() -> argparse.ArgumentParser: add_logging_arguments(parser) train_util.add_sd_models_arguments(parser) + train_util.add_fm_training_arguments(parser) train_util.add_dataset_arguments(parser, True, True, True) train_util.add_training_arguments(parser, True) train_util.add_masked_loss_arguments(parser)