diff --git a/flux_train_network.py b/flux_train_network.py index def44155..712d0bc8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -21,6 +21,13 @@ from library import ( strategy_flux, train_util, ) +from library.custom_train_functions import ( + prepare_scheduler_for_custom_training, + apply_snr_weight, + scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, + apply_debiased_estimation, +) from library.utils import setup_logging setup_logging() @@ -326,6 +333,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + prepare_scheduler_for_custom_training(noise_scheduler, device) return noise_scheduler def encode_images_to_latents(self, args, vae, images): @@ -450,7 +458,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return model_pred, target, timesteps, weighting - def post_process_loss(self, loss, args, timesteps, noise_scheduler): + def post_process_loss(self, loss: torch.Tensor, args, timesteps, noise_scheduler) -> torch.FloatTensor: + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) return loss def get_sai_model_spec(self, args): diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index ad3e69ff..2d683693 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -6,6 +6,7 @@ import re from torch.types import Number from typing import List, Optional, Union from .utils import setup_logging +from library import train_util setup_logging() import logging @@ -17,7 +18,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): return - alphas_cumprod = noise_scheduler.alphas_cumprod + alphas_cumprod = train_util.get_alphas_cumprod(noise_scheduler) sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) alpha = sqrt_alphas_cumprod @@ -66,7 +67,8 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): - snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) @@ -81,9 +83,9 @@ def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: loss = loss * scale return loss - def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) # # show debug info @@ -99,7 +101,12 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + if not hasattr(noise_scheduler, "all_snr"): + return loss + + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 if v_prediction: weight = 1 / (snr_t + 1) diff --git a/library/train_util.py b/library/train_util.py index 1f591c42..e92d4518 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5985,9 +5985,11 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): + alphas_cumprod = get_alphas_cumprod(noise_scheduler) + if alphas_cumprod is None: raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + timesteps_indices = index_for_timesteps(timesteps, noise_scheduler) + alphas_cumprod = torch.index_select(alphas_cumprod.to(timesteps.device), 0, timesteps_indices) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c result = result.to(timesteps.device) @@ -5998,6 +6000,64 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler return result +def index_for_timesteps(timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: + if hasattr(noise_scheduler, "index_for_timestep"): + noise_scheduler.timesteps = noise_scheduler.timesteps.to(timesteps.device) + # Convert timesteps to appropriate indices using the scheduler's method + indices = [] + for t in timesteps: + # Make sure t is a tensor with the right device + t_tensor = t if isinstance(t, torch.Tensor) else torch.tensor([t], device=timesteps.device)[0] + try: + # Use the scheduler's method to get the correct index + idx = noise_scheduler.index_for_timestep(t_tensor) + indices.append(idx) + except IndexError: + # Handle case where no exact match is found + schedule_timesteps = noise_scheduler.timesteps + closest_idx = torch.abs(schedule_timesteps - t_tensor).argmin().item() + indices.append(closest_idx) + timesteps_indices = torch.tensor(indices, device=timesteps.device, dtype=torch.long) + else: + timesteps_indices = timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + return timesteps_indices + +def timesteps_to_indices(timesteps: torch.Tensor, num_train_timesteps: int): + """ + Convert the timesteps into indices by converting the timestep into an long integer. + + Accounts for timestep being within range 0 to 1 and 1 to 1000. + """ + # Check if timesteps are normalized (between 0-1) or absolute (1-1000) + if torch.max(timesteps) <= 1.0: + # Timesteps are normalized, scale them to indices + timesteps_indices = (timesteps * (num_train_timesteps - 1)).round().to(torch.long) + else: + # Timesteps are already in the range of 1 to num_train_timesteps + # We may need to adjust indices if timesteps start from 1 but indices from 0 + timesteps_indices = (timesteps - 1).round().to(torch.long).clamp(0, num_train_timesteps - 1) + + return timesteps_indices + +def get_alphas_cumprod(noise_scheduler) -> Optional[torch.Tensor]: + """ + Get the cumulative product of the alpha values across the timesteps. + + We use the noise scheduler to get the timesteps or use alphas_cumprod. + """ + if hasattr(noise_scheduler, "alphas_cumprod"): + alphas_cumprod = noise_scheduler.alphas_cumprod + elif hasattr(noise_scheduler, "sigmas"): + # Since we don't have alphas_cumprod directly, we can derive it from sigmas + sigmas = noise_scheduler.sigmas + + # In many diffusion models, sigma² = (1-α)/α where α is the cumulative product of alphas + # So we can derive alphas_cumprod from sigmas + alphas_cumprod = 1.0 / (1.0 + sigmas**2) + else: + return None + + return alphas_cumprod def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None