From 9b3d3332a2cb4c50e9daa45f41002ae767a393f6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 18 Mar 2025 15:15:44 -0400 Subject: [PATCH 1/5] Support alpha cumulative product using shifted sigmas for Flux --- flux_train_network.py | 18 ++++++++- library/custom_train_functions.py | 17 +++++--- library/train_util.py | 64 ++++++++++++++++++++++++++++++- 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index def44155..712d0bc8 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -21,6 +21,13 @@ from library import ( strategy_flux, train_util, ) +from library.custom_train_functions import ( + prepare_scheduler_for_custom_training, + apply_snr_weight, + scale_v_prediction_loss_like_noise_prediction, + add_v_prediction_like_loss, + apply_debiased_estimation, +) from library.utils import setup_logging setup_logging() @@ -326,6 +333,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) + prepare_scheduler_for_custom_training(noise_scheduler, device) return noise_scheduler def encode_images_to_latents(self, args, vae, images): @@ -450,7 +458,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return model_pred, target, timesteps, weighting - def post_process_loss(self, loss, args, timesteps, noise_scheduler): + def post_process_loss(self, loss: torch.Tensor, args, timesteps, noise_scheduler) -> torch.FloatTensor: + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + if args.scale_v_pred_loss_like_noise_pred: + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + if args.v_pred_like_loss: + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + if args.debiased_estimation_loss: + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) return loss def get_sai_model_spec(self, args): diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index ad3e69ff..2d683693 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -6,6 +6,7 @@ import re from torch.types import Number from typing import List, Optional, Union from .utils import setup_logging +from library import train_util setup_logging() import logging @@ -17,7 +18,7 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): return - alphas_cumprod = noise_scheduler.alphas_cumprod + alphas_cumprod = train_util.get_alphas_cumprod(noise_scheduler) sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) alpha = sqrt_alphas_cumprod @@ -66,7 +67,8 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): - snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) @@ -81,9 +83,9 @@ def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: loss = loss * scale return loss - def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) # # show debug info @@ -99,7 +101,12 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size + if not hasattr(noise_scheduler, "all_snr"): + return loss + + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 if v_prediction: weight = 1 / (snr_t + 1) diff --git a/library/train_util.py b/library/train_util.py index 1f591c42..e92d4518 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5985,9 +5985,11 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - if not hasattr(noise_scheduler, "alphas_cumprod"): + alphas_cumprod = get_alphas_cumprod(noise_scheduler) + if alphas_cumprod is None: raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") - alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) + timesteps_indices = index_for_timesteps(timesteps, noise_scheduler) + alphas_cumprod = torch.index_select(alphas_cumprod.to(timesteps.device), 0, timesteps_indices) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c result = result.to(timesteps.device) @@ -5998,6 +6000,64 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler return result +def index_for_timesteps(timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor: + if hasattr(noise_scheduler, "index_for_timestep"): + noise_scheduler.timesteps = noise_scheduler.timesteps.to(timesteps.device) + # Convert timesteps to appropriate indices using the scheduler's method + indices = [] + for t in timesteps: + # Make sure t is a tensor with the right device + t_tensor = t if isinstance(t, torch.Tensor) else torch.tensor([t], device=timesteps.device)[0] + try: + # Use the scheduler's method to get the correct index + idx = noise_scheduler.index_for_timestep(t_tensor) + indices.append(idx) + except IndexError: + # Handle case where no exact match is found + schedule_timesteps = noise_scheduler.timesteps + closest_idx = torch.abs(schedule_timesteps - t_tensor).argmin().item() + indices.append(closest_idx) + timesteps_indices = torch.tensor(indices, device=timesteps.device, dtype=torch.long) + else: + timesteps_indices = timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + return timesteps_indices + +def timesteps_to_indices(timesteps: torch.Tensor, num_train_timesteps: int): + """ + Convert the timesteps into indices by converting the timestep into an long integer. + + Accounts for timestep being within range 0 to 1 and 1 to 1000. + """ + # Check if timesteps are normalized (between 0-1) or absolute (1-1000) + if torch.max(timesteps) <= 1.0: + # Timesteps are normalized, scale them to indices + timesteps_indices = (timesteps * (num_train_timesteps - 1)).round().to(torch.long) + else: + # Timesteps are already in the range of 1 to num_train_timesteps + # We may need to adjust indices if timesteps start from 1 but indices from 0 + timesteps_indices = (timesteps - 1).round().to(torch.long).clamp(0, num_train_timesteps - 1) + + return timesteps_indices + +def get_alphas_cumprod(noise_scheduler) -> Optional[torch.Tensor]: + """ + Get the cumulative product of the alpha values across the timesteps. + + We use the noise scheduler to get the timesteps or use alphas_cumprod. + """ + if hasattr(noise_scheduler, "alphas_cumprod"): + alphas_cumprod = noise_scheduler.alphas_cumprod + elif hasattr(noise_scheduler, "sigmas"): + # Since we don't have alphas_cumprod directly, we can derive it from sigmas + sigmas = noise_scheduler.sigmas + + # In many diffusion models, sigma² = (1-α)/α where α is the cumulative product of alphas + # So we can derive alphas_cumprod from sigmas + alphas_cumprod = 1.0 / (1.0 + sigmas**2) + else: + return None + + return alphas_cumprod def conditional_loss( model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None From 8d5a183cc5cf8b00aa27c1cb2013fef46e34e3e4 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 15:40:26 -0400 Subject: [PATCH 2/5] Fix applying image size to post_process_loss --- flux_train_network.py | 20 +- library/custom_train_functions.py | 85 +++++--- library/sd3_train_utils.py | 322 ++++++++++++++++++++++++++---- library/train_util.py | 27 ++- sd3_train_network.py | 2 +- train_network.py | 6 +- 6 files changed, 383 insertions(+), 79 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index 712d0bc8..b875f678 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -22,7 +22,7 @@ from library import ( train_util, ) from library.custom_train_functions import ( - prepare_scheduler_for_custom_training, + prepare_scheduler_for_custom_training_flux, apply_snr_weight, scale_v_prediction_loss_like_noise_prediction, add_v_prediction_like_loss, @@ -331,9 +331,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): """ def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any: - noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift) + noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift, use_dynamic_shifting=args.timestep_sampling == "flux_shift") self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) - prepare_scheduler_for_custom_training(noise_scheduler, device) + prepare_scheduler_for_custom_training_flux(noise_scheduler, device) return noise_scheduler def encode_images_to_latents(self, args, vae, images): @@ -458,15 +458,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): return model_pred, target, timesteps, weighting - def post_process_loss(self, loss: torch.Tensor, args, timesteps, noise_scheduler) -> torch.FloatTensor: + def post_process_loss(self, loss: torch.Tensor, args, timesteps, noise_scheduler, latents: Optional[torch.Tensor]) -> torch.FloatTensor: + image_size = None + if latents is not None: + image_size = tuple(latents.shape[-2:]) + if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization, image_size) if args.scale_v_pred_loss_like_noise_pred: - loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler, image_size) if args.v_pred_like_loss: - loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) + loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss, image_size) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization, image_size) return loss def get_sai_model_spec(self, args): diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 2d683693..2a657a9f 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -18,6 +18,9 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): if hasattr(noise_scheduler, "all_snr"): return + if hasattr(noise_scheduler.config, "use_dynamic_shifting") and noise_scheduler.config.use_dynamic_shifting is True: + return + alphas_cumprod = train_util.get_alphas_cumprod(noise_scheduler) sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) @@ -27,6 +30,22 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device): noise_scheduler.all_snr = all_snr.to(device) +def prepare_scheduler_for_custom_training_flux(noise_scheduler, device): + if hasattr(noise_scheduler, "all_snr"): + return + + if hasattr(noise_scheduler.config, "use_dynamic_shifting") and noise_scheduler.config.use_dynamic_shifting is True: + return + + alphas_cumprod = train_util.get_alphas_cumprod(noise_scheduler) + if alphas_cumprod is None: + return + + sigma = 1.0 - alphas_cumprod + all_snr = (alphas_cumprod / sigma) + + noise_scheduler.all_snr = all_snr.to(device) + def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): # fix beta: zero terminal SNR @@ -66,9 +85,14 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): noise_scheduler.alphas_cumprod = alphas_cumprod -def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): - timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) - snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) +def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False, image_size=None): + # Get the appropriate SNR values based on timesteps and potentially image size + if hasattr(noise_scheduler, "get_snr_for_timestep"): + snr = noise_scheduler.get_snr_for_timestep(timesteps, image_size) + else: + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) + min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) if v_prediction: snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) @@ -78,14 +102,19 @@ def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_sched return loss -def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): - scale = get_snr_scale(timesteps, noise_scheduler) +def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None): + scale = get_snr_scale(timesteps, noise_scheduler, image_size) loss = loss * scale return loss -def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): - timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) # batch_size +def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None): + # Get SNR values with image_size consideration + if hasattr(noise_scheduler, "get_snr_for_timestep"): + snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size) + else: + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 scale = snr_t / (snr_t + 1) # # show debug info @@ -93,27 +122,37 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler): return scale -def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor): - scale = get_snr_scale(timesteps, noise_scheduler) +def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor, image_size=None): + scale = get_snr_scale(timesteps, noise_scheduler, image_size) # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}") loss = loss + loss / scale * v_pred_like_loss return loss -def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): - if not hasattr(noise_scheduler, "all_snr"): - return loss +def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None): + # Check if we have SNR values available + if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")): + return loss + + # Get SNR values with image_size consideration + if hasattr(noise_scheduler, "get_snr_for_timestep"): + snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size) + else: + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) + + # Cap the SNR to avoid numerical issues + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) + + # Apply weighting based on prediction type + if v_prediction: + weight = 1 / (snr_t + 1) + else: + weight = 1 / torch.sqrt(snr_t) + + loss = weight * loss + return loss - timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) - - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) # batch_size - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - if v_prediction: - weight = 1 / (snr_t + 1) - else: - weight = 1 / torch.sqrt(snr_t) - loss = weight * loss - return loss # TODO train_utilと分散しているのでどちらかに寄せる diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c4079884..1fb34c74 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -28,7 +28,7 @@ import logging logger = logging.getLogger(__name__) -from library import sd3_models, sd3_utils, strategy_base, train_util +from library import sd3_models, sd3_utils, strategy_base, train_util, flux_train_utils def save_models( @@ -598,16 +598,29 @@ def sample_image_inference( # region Diffusers +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import SchedulerMixin -from diffusers.utils.torch_utils import randn_tensor from diffusers.utils import BaseOutput +from diffusers.schedulers.scheduling_utils import SchedulerMixin @dataclass @@ -649,22 +662,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): self, num_train_timesteps: int = 1000, shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + shift_terminal: Optional[float] = None, + use_karras_sigmas: Optional[bool] = False, + use_exponential_sigmas: Optional[bool] = False, + use_beta_sigmas: Optional[bool] = False, ): + if self.config.use_beta_sigmas and not is_scipy_available(): + raise ImportError("Make sure to install scipy if you want to use beta sigmas.") + if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1: + raise ValueError( + "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used." + ) timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) sigmas = timesteps / num_train_timesteps - sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) self.timesteps = sigmas * num_train_timesteps self._step_index = None self._begin_index = None + self._shift = shift + self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigma_min = self.sigmas[-1].item() self.sigma_max = self.sigmas[0].item() + @property + def shift(self): + """ + The value used for shifting. + """ + return self._shift + @property def step_index(self): """ @@ -690,6 +730,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ self._begin_index = begin_index + def set_shift(self, shift: float): + self._shift = shift + def scale_noise( self, sample: torch.FloatTensor, @@ -709,10 +752,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: A scaled input sample. """ - if self.step_index is None: - self._init_step_index(timestep) + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype) + + if sample.device.type == "mps" and torch.is_floating_point(timestep): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32) + timestep = timestep.to(sample.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(sample.device) + timestep = timestep.to(sample.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timestep.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timestep.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(sample.shape): + sigma = sigma.unsqueeze(-1) - sigma = self.sigmas[self.step_index] sample = sigma * noise + (1.0 - sigma) * sample return sample @@ -720,7 +784,37 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): def _sigma_to_t(self, sigma): return sigma * self.config.num_train_timesteps - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor: + r""" + Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config + value. + + Reference: + https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51 + + Args: + t (`torch.Tensor`): + A tensor of timesteps to be stretched and shifted. + + Returns: + `torch.Tensor`: + A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`. + """ + one_minus_z = 1 - t + scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal) + stretched_t = 1 - (one_minus_z / scale_factor) + return stretched_t + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[float] = None, + ): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -730,18 +824,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ + if self.config.use_dynamic_shifting and mu is None: + raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`") + + if sigmas is None: + timesteps = np.linspace( + self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps + ) + + sigmas = timesteps / self.config.num_train_timesteps + else: + sigmas = np.array(sigmas).astype(np.float32) + num_inference_steps = len(sigmas) self.num_inference_steps = num_inference_steps - timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) + else: + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + if self.config.shift_terminal: + sigmas = self.stretch_shift_to_terminal(sigmas) + + if self.config.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_exponential_sigmas: + sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + + elif self.config.use_beta_sigmas: + sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) - sigmas = timesteps / self.config.num_train_timesteps - sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) - timesteps = sigmas * self.config.num_train_timesteps - self.timesteps = timesteps.to(device=device) - self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + if self.config.invert_sigmas: + sigmas = 1.0 - sigmas + timesteps = sigmas * self.config.num_train_timesteps + sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)]) + else: + sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + self.timesteps = timesteps.to(device=device) + self.sigmas = sigmas self._step_index = None self._begin_index = None @@ -807,7 +932,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): returned, otherwise a tuple is returned where the first element is the sample tensor. """ - if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): raise ValueError( ( "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" @@ -823,30 +952,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): sample = sample.to(torch.float32) sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] - gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 + prev_sample = sample + (sigma_next - sigma) * model_output - noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator) - - eps = noise * s_noise - sigma_hat = sigma * (gamma + 1) - - if gamma > 0: - sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - # NOTE: "original_sample" should not be an expected prediction_type but is left in for - # backwards compatibility - - # if self.config.prediction_type == "vector_field": - - denoised = sample - model_output * sigma - # 2. Convert to an ODE derivative - derivative = (sample - denoised) / sigma_hat - - dt = self.sigmas[self.step_index + 1] - sigma_hat - - prev_sample = sample + derivative * dt # Cast sample back to model compatible dtype prev_sample = prev_sample.to(model_output.dtype) @@ -858,9 +967,146 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential + def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor: + """Constructs an exponential noise schedule.""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps)) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta + def _convert_to_beta( + self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6 + ) -> torch.Tensor: + """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)""" + + # Hack to make sure that other schedulers which copy this function don't break + # TODO: Add this logic to the other schedulers + if hasattr(self.config, "sigma_min"): + sigma_min = self.config.sigma_min + else: + sigma_min = None + + if hasattr(self.config, "sigma_max"): + sigma_max = self.config.sigma_max + else: + sigma_max = None + + sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item() + sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item() + + sigmas = np.array( + [ + sigma_min + (ppf * (sigma_max - sigma_min)) + for ppf in [ + scipy.stats.beta.ppf(timestep, alpha, beta) + for timestep in 1 - np.linspace(0, 1, num_inference_steps) + ] + ] + ) + return sigmas + def __len__(self): return self.config.num_train_timesteps + def get_snr_for_timestep(self, timesteps: torch.IntTensor, image_size=None): + """ + Get the signal-to-noise ratio for given timesteps, with consideration for image size. + + Args: + timesteps: Batch of timesteps (already scaled values, timesteps = sigma * 1000.0) + image_size: Tuple of (height, width) or single int representing image dimensions + + Returns: + torch.Tensor: SNR values corresponding to the timesteps + """ + + if not hasattr(self, "all_snr"): + all_sigmas = self.sigmas + assert isinstance(all_sigmas, torch.Tensor), "FlowMatch scheduler sigmas are not tensors" + + # Apply appropriate shifting to sigmas + if image_size is not None and self.config.use_dynamic_shifting: + # Calculate mu based on image dimensions + if isinstance(image_size, (tuple, list)): + h, w = image_size + else: + h = w = image_size + + # Adjust for packed size + h = h // 2 + w = w // 2 + mu = flux_train_utils.get_lin_function(y1=0.5, y2=1.15)(h * w) + + # Apply time shifting to sigmas + shifted_all_sigmas = self.time_shift(mu, 1.0, all_sigmas) + elif not self.config.use_dynamic_shifting: + # already shifted + shifted_all_sigmas = all_sigmas + else: + shifted_all_sigmas = all_sigmas + + # Calculate SNR based on shifted sigma values + all_snr = ((1.0 - shifted_all_sigmas**2) / (shifted_all_sigmas**2)).to(timesteps.device) + + # If we are using dynamic shifting we can't store all the snr + if not self.config.use_dynamic_shifting: + self.all_snr = all_snr + else: + all_snr = self.all_snr + + + # Convert input timesteps to indices + # Assuming timesteps are in the range [0, 1000] and need to be mapped to indices + timestep_indices = (timesteps / 1000.0 * (len(all_snr.to(timesteps.device)) - 1)).long() + + # Get SNR values for the requested timesteps + requested_snr = all_snr[timestep_indices] + + return requested_snr + + def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) diff --git a/library/train_util.py b/library/train_util.py index e92d4518..9390be6b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -31,6 +31,7 @@ from packaging.version import Version import torch from library.device_utils import init_ipex, clean_memory_on_device +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy init_ipex() @@ -60,7 +61,7 @@ from diffusers import ( KDPM2AncestralDiscreteScheduler, AutoencoderKL, ) -from library import custom_train_functions, sd3_utils +from library import custom_train_functions, sd3_utils, flux_train_utils from library.original_unet import UNet2DConditionModel from huggingface_hub import hf_hub_download import numpy as np @@ -5976,7 +5977,7 @@ def get_noise_noisy_latents_and_timesteps( return noise, noisy_latents, timesteps -def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: +def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, latents: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): return None @@ -5985,12 +5986,23 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps result = torch.exp(-alpha * timesteps) * args.huber_scale elif args.huber_schedule == "snr": - alphas_cumprod = get_alphas_cumprod(noise_scheduler) + if hasattr(noise_scheduler, "sigmas"): + # Need to adjust the timesteps based on the latent dimensions + if args.timestep_sampling == "flux_shift": + _, _, h, w = latents.shape + mu = flux_train_utils.get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) + alphas_cumprod = get_alphas_cumprod(noise_scheduler, mu) + else: + alphas_cumprod = get_alphas_cumprod(noise_scheduler) + else: + alphas_cumprod = get_alphas_cumprod(noise_scheduler) + if alphas_cumprod is None: raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") timesteps_indices = index_for_timesteps(timesteps, noise_scheduler) alphas_cumprod = torch.index_select(alphas_cumprod.to(timesteps.device), 0, timesteps_indices) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 + result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c result = result.to(timesteps.device) elif args.huber_schedule == "constant": @@ -6039,7 +6051,7 @@ def timesteps_to_indices(timesteps: torch.Tensor, num_train_timesteps: int): return timesteps_indices -def get_alphas_cumprod(noise_scheduler) -> Optional[torch.Tensor]: +def get_alphas_cumprod(noise_scheduler, mu=None) -> Optional[torch.Tensor]: """ Get the cumulative product of the alpha values across the timesteps. @@ -6048,8 +6060,11 @@ def get_alphas_cumprod(noise_scheduler) -> Optional[torch.Tensor]: if hasattr(noise_scheduler, "alphas_cumprod"): alphas_cumprod = noise_scheduler.alphas_cumprod elif hasattr(noise_scheduler, "sigmas"): - # Since we don't have alphas_cumprod directly, we can derive it from sigmas - sigmas = noise_scheduler.sigmas + if noise_scheduler.config.use_dynamic_shifting is True: + sigmas = noise_scheduler.time_shift(mu, 1.0, noise_scheduler.sigmas) + else: + # Since we don't have alphas_cumprod directly, we can derive it from sigmas + sigmas = noise_scheduler.sigmas # In many diffusion models, sigma² = (1-α)/α where α is the cumulative product of alphas # So we can derive alphas_cumprod from sigmas diff --git a/sd3_train_network.py b/sd3_train_network.py index cdb7aa4e..01265733 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -391,7 +391,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): return model_pred, target, timesteps, weighting - def post_process_loss(self, loss, args, timesteps, noise_scheduler): + def post_process_loss(self, loss, args, timesteps, noise_scheduler, latents): return loss def get_sai_model_spec(self, args): diff --git a/train_network.py b/train_network.py index 2d279b3b..5a4d83f0 100644 --- a/train_network.py +++ b/train_network.py @@ -316,7 +316,7 @@ class NetworkTrainer: return noise_pred, target, timesteps, None - def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: + def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler, latents: Optional[torch.Tensor]) -> torch.FloatTensor: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: @@ -442,7 +442,7 @@ class NetworkTrainer: is_train=is_train, ) - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) if weighting is not None: loss = loss * weighting @@ -453,7 +453,7 @@ class NetworkTrainer: loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights - loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) + loss = self.post_process_loss(loss, args, timesteps, noise_scheduler, latents) return loss.mean() From 3ffd3b84a55a4c5227b62b8226e97fb40000a249 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 16:58:49 -0400 Subject: [PATCH 3/5] Add custom train functions test for loss modifications --- library/custom_train_functions.py | 47 ++-- tests/library/test_custom_train_functions.py | 227 +++++++++++++++++++ 2 files changed, 252 insertions(+), 22 deletions(-) create mode 100644 tests/library/test_custom_train_functions.py diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 2a657a9f..7eaefc32 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -87,7 +87,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False, image_size=None): # Get the appropriate SNR values based on timesteps and potentially image size - if hasattr(noise_scheduler, "get_snr_for_timestep"): + if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep): snr = noise_scheduler.get_snr_for_timestep(timesteps, image_size) else: timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) @@ -109,7 +109,7 @@ def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None): # Get SNR values with image_size consideration - if hasattr(noise_scheduler, "get_snr_for_timestep"): + if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep): snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size) else: timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) @@ -131,27 +131,30 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None): # Check if we have SNR values available - if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")): - return loss + if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")): + return loss - # Get SNR values with image_size consideration - if hasattr(noise_scheduler, "get_snr_for_timestep"): - snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size) - else: - timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) - - # Cap the SNR to avoid numerical issues - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) - - # Apply weighting based on prediction type - if v_prediction: - weight = 1 / (snr_t + 1) - else: - weight = 1 / torch.sqrt(snr_t) - - loss = weight * loss - return loss + if not callable(noise_scheduler.get_snr_for_timestep): + return loss + + # Get SNR values with image_size consideration + if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep): + snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size) + else: + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) + + # Cap the SNR to avoid numerical issues + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) + + # Apply weighting based on prediction type + if v_prediction: + weight = 1 / (snr_t + 1) + else: + weight = 1 / torch.sqrt(snr_t) + + loss = weight * loss + return loss diff --git a/tests/library/test_custom_train_functions.py b/tests/library/test_custom_train_functions.py new file mode 100644 index 00000000..8bb4f6f9 --- /dev/null +++ b/tests/library/test_custom_train_functions.py @@ -0,0 +1,227 @@ +import pytest +import torch +import numpy as np +from unittest.mock import MagicMock, patch + +# Import the functions we're testing +from library.custom_train_functions import ( + apply_snr_weight, + scale_v_prediction_loss_like_noise_prediction, + get_snr_scale, + add_v_prediction_like_loss, + apply_debiased_estimation, +) + + +@pytest.fixture +def loss(): + return torch.ones(2, 1) + + +@pytest.fixture +def timesteps(): + return torch.tensor([[200, 200]], dtype=torch.int32) + + +@pytest.fixture +def noise_scheduler(): + scheduler = MagicMock() + scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([10.0, 5.0])) + scheduler.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + return scheduler + + +# Tests for apply_snr_weight +def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler): + image_size = 64 + gamma = 5.0 + + result = apply_snr_weight( + loss, + timesteps, + noise_scheduler, + gamma, + v_prediction=False, + image_size=image_size, + ) + + expected_result = torch.tensor([[0.5, 1.0]]) + + assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) + + +def test_apply_snr_weight_with_all_snr(loss, timesteps): + gamma = 5.0 + + # Modify the mock to not use get_snr_for_timestep + mock_noise_scheduler_no_method = MagicMock() + mock_noise_scheduler_no_method.get_snr_for_timestep = None + mock_noise_scheduler_no_method.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0]) + + result = apply_snr_weight(loss, timesteps, mock_noise_scheduler_no_method, gamma, v_prediction=False) + + expected_result = torch.tensor([1.0, 1.0]) + assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) + + +def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler): + gamma = 5.0 + + result = apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=True) + + expected_result = torch.tensor([[0.4545, 0.8333], [0.4545, 0.8333]]) + + assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) + + +# Tests for scale_v_prediction_loss_like_noise_prediction +def test_scale_v_prediction_loss(loss, timesteps, noise_scheduler): + with patch("library.custom_train_functions.get_snr_scale") as mock_get_snr_scale: + mock_get_snr_scale.return_value = torch.tensor([0.9, 0.8]) + + result = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) + + mock_get_snr_scale.assert_called_once_with(timesteps, noise_scheduler, None) + + # Apply broadcasting for multiplication + scale = torch.tensor([[0.9, 0.8], [0.9, 0.8]]) + expected_result = loss * scale + assert torch.allclose(result, expected_result) + + +# Tests for get_snr_scale +def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler): + image_size = 64 + + result = get_snr_scale(timesteps, noise_scheduler, image_size) + + # Verify the method was called correctly + noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, image_size) + + # Calculate expected values (snr / (snr + 1)) + snr = torch.tensor([10.0, 5.0]) + expected_scale = snr / (snr + 1) + + assert torch.allclose(result, expected_scale) + + +def test_get_snr_scale_with_all_snr(timesteps): + # Create a scheduler that only has all_snr + mock_scheduler_all_snr = MagicMock() + mock_scheduler_all_snr.get_snr_for_timestep = None + mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + + result = get_snr_scale(timesteps, mock_scheduler_all_snr) + + expected_scale = torch.tensor([[0.9524, 0.9524]]) + + assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4) + + +def test_get_snr_scale_with_large_snr(timesteps, noise_scheduler): + # Set up the mock with a very large SNR value + noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([2000.0, 5.0]) + + result = get_snr_scale(timesteps, noise_scheduler) + + expected_scale = torch.tensor([0.9990, 0.8333]) + + assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4) + + +# Tests for add_v_prediction_like_loss +def test_add_v_prediction_like_loss(loss, timesteps, noise_scheduler): + v_pred_like_loss = torch.tensor([0.3, 0.2]).view(2, 1) + + with patch("library.custom_train_functions.get_snr_scale") as mock_get_snr_scale: + mock_get_snr_scale.return_value = torch.tensor([0.9, 0.8]) + + result = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss) + + mock_get_snr_scale.assert_called_once_with(timesteps, noise_scheduler, None) + + expected_result = torch.tensor([[1.3333, 1.3750], [1.2222, 1.2500]]) + assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) + + +# Tests for apply_debiased_estimation +def test_apply_debiased_estimation_no_snr(loss, timesteps): + # Create a scheduler without SNR methods + scheduler_without_snr = MagicMock() + # Need to explicitly set attribute to None instead of deleting + scheduler_without_snr.get_snr_for_timestep = None + + result = apply_debiased_estimation(loss, timesteps, scheduler_without_snr) + + # When no SNR methods are available, the function should return the loss unchanged + assert torch.equal(result, loss) + + +def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_scheduler): + # Test with v_prediction=False + result_no_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False) + + expected_result_no_v = torch.tensor([[0.3162, 0.4472], [0.3162, 0.4472]]) + + assert torch.allclose(result_no_v, expected_result_no_v, rtol=1e-4, atol=1e-4) + + # Test with v_prediction=True + result_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=True) + + expected_result_v = torch.tensor([[0.0909, 0.1667], [0.0909, 0.1667]]) + + assert torch.allclose(result_v, expected_result_v, rtol=1e-4, atol=1e-4) + + +def test_apply_debiased_estimation_with_all_snr(loss, timesteps): + # Create a scheduler that only has all_snr + mock_scheduler_all_snr = MagicMock() + mock_scheduler_all_snr.get_snr_for_timestep = None + mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + + result = apply_debiased_estimation(loss, timesteps, mock_scheduler_all_snr, v_prediction=False) + + expected_result = torch.tensor([[1.0, 1.0]]) + + assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) + + +def test_apply_debiased_estimation_with_large_snr(loss, timesteps, noise_scheduler): + # Set up the mock with a very large SNR value + noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([2000.0, 5.0]) + + result = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False) + + expected_result = torch.tensor([[0.0316, 0.4472], [0.0316, 0.4472]]) + + assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) + + +# Additional edge cases +def test_empty_tensors(noise_scheduler): + # Test with empty tensors + loss = torch.tensor([], dtype=torch.float32) + timesteps = torch.tensor([], dtype=torch.int32) + + assert isinstance(timesteps, torch.IntTensor) + + noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([], dtype=torch.float32) + + result = apply_snr_weight(loss, timesteps, noise_scheduler, gamma=5.0) + + assert result.shape == loss.shape + assert result.dtype == loss.dtype + + +def test_different_device_compatibility(loss, timesteps, noise_scheduler): + gamma = 5.0 + device = torch.device("cpu") + + # For a real device test, we need to create actual tensors on devices + loss_on_device = loss.to(device) + + # Mock the SNR tensor that would be returned with proper device handling + snr_tensor = torch.tensor([0.204, 0.294], device=device) + noise_scheduler.get_snr_for_timestep.return_value = snr_tensor + + result = apply_snr_weight(loss_on_device, timesteps, noise_scheduler, gamma) From 2cadeaff0ad9a21b2e0221ce91b48b29dcf65a2b Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 20 Mar 2025 17:15:37 -0400 Subject: [PATCH 4/5] Fix typical snr values to be in appropriate range --- pytest.ini | 1 + tests/library/test_custom_train_functions.py | 26 +++++++++----------- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . diff --git a/tests/library/test_custom_train_functions.py b/tests/library/test_custom_train_functions.py index 8bb4f6f9..c31f7d9d 100644 --- a/tests/library/test_custom_train_functions.py +++ b/tests/library/test_custom_train_functions.py @@ -20,14 +20,14 @@ def loss(): @pytest.fixture def timesteps(): - return torch.tensor([[200, 200]], dtype=torch.int32) + return torch.tensor([[200, 600]], dtype=torch.int32) @pytest.fixture def noise_scheduler(): scheduler = MagicMock() - scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([10.0, 5.0])) - scheduler.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + scheduler.get_snr_for_timestep = MagicMock(return_value=torch.tensor([0.294, 0.39])) + scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0]) return scheduler @@ -45,7 +45,7 @@ def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler): image_size=image_size, ) - expected_result = torch.tensor([[0.5, 1.0]]) + expected_result = torch.tensor([[1.0, 1.0]]) assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) @@ -69,7 +69,7 @@ def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler): result = apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=True) - expected_result = torch.tensor([[0.4545, 0.8333], [0.4545, 0.8333]]) + expected_result = torch.tensor([[0.2272, 0.2806], [0.2272, 0.2806]]) assert torch.allclose(result, expected_result, rtol=1e-4, atol=1e-4) @@ -98,22 +98,20 @@ def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler): # Verify the method was called correctly noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, image_size) - # Calculate expected values (snr / (snr + 1)) - snr = torch.tensor([10.0, 5.0]) - expected_scale = snr / (snr + 1) + expected_scale = torch.tensor([0.2272, 0.2806]) - assert torch.allclose(result, expected_scale) + assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4) def test_get_snr_scale_with_all_snr(timesteps): # Create a scheduler that only has all_snr mock_scheduler_all_snr = MagicMock() mock_scheduler_all_snr.get_snr_for_timestep = None - mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 1.0]) result = get_snr_scale(timesteps, mock_scheduler_all_snr) - expected_scale = torch.tensor([[0.9524, 0.9524]]) + expected_scale = torch.tensor([[0.5000, 0.5000]]) assert torch.allclose(result, expected_scale, rtol=1e-4, atol=1e-4) @@ -161,14 +159,14 @@ def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_sc # Test with v_prediction=False result_no_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False) - expected_result_no_v = torch.tensor([[0.3162, 0.4472], [0.3162, 0.4472]]) + expected_result_no_v = torch.tensor([[1.8443, 1.6013], [1.8443, 1.6013]]) assert torch.allclose(result_no_v, expected_result_no_v, rtol=1e-4, atol=1e-4) # Test with v_prediction=True result_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=True) - expected_result_v = torch.tensor([[0.0909, 0.1667], [0.0909, 0.1667]]) + expected_result_v = torch.tensor([[0.7728, 0.7194], [0.7728, 0.7194]]) assert torch.allclose(result_v, expected_result_v, rtol=1e-4, atol=1e-4) @@ -177,7 +175,7 @@ def test_apply_debiased_estimation_with_all_snr(loss, timesteps): # Create a scheduler that only has all_snr mock_scheduler_all_snr = MagicMock() mock_scheduler_all_snr.get_snr_for_timestep = None - mock_scheduler_all_snr.all_snr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 20.0]) + mock_scheduler_all_snr.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0]) result = apply_debiased_estimation(loss, timesteps, mock_scheduler_all_snr, v_prediction=False) From 4c8ebf7293fb1ed5753748156823965dc835ba79 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 16 Jun 2025 17:22:32 -0400 Subject: [PATCH 5/5] Add more tests --- tests/library/test_custom_train_functions.py | 39 +++++++++++ tests/library/test_flux_train_utils.py | 68 ++++++++++++++++++++ 2 files changed, 107 insertions(+) diff --git a/tests/library/test_custom_train_functions.py b/tests/library/test_custom_train_functions.py index c31f7d9d..97c044a9 100644 --- a/tests/library/test_custom_train_functions.py +++ b/tests/library/test_custom_train_functions.py @@ -223,3 +223,42 @@ def test_different_device_compatibility(loss, timesteps, noise_scheduler): noise_scheduler.get_snr_for_timestep.return_value = snr_tensor result = apply_snr_weight(loss_on_device, timesteps, noise_scheduler, gamma) + +# Additional tests for new functionality +def test_apply_snr_weight_with_image_size(loss, timesteps, noise_scheduler): + """Test SNR weight application with image size consideration""" + gamma = 5.0 + image_sizes = [None, 64, (256, 256)] + + for image_size in image_sizes: + result = apply_snr_weight( + loss, + timesteps, + noise_scheduler, + gamma, + v_prediction=False, + image_size=image_size + ) + + # Allow for broadcasting + assert result.shape[0] == loss.shape[0] + assert result.dtype == loss.dtype + +def test_apply_debiased_estimation_variations(loss, timesteps, noise_scheduler): + """Test debiased estimation with different image sizes and prediction types""" + image_sizes = [None, 64, (256, 256)] + prediction_types = [True, False] + + for image_size in image_sizes: + for v_prediction in prediction_types: + result = apply_debiased_estimation( + loss, + timesteps, + noise_scheduler, + v_prediction=v_prediction, + image_size=image_size + ) + + # Allow for broadcasting + assert result.shape[0] == loss.shape[0] + assert result.dtype == loss.dtype diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 2ad7ce4e..e5baba00 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -1,6 +1,8 @@ import pytest import torch +import math from unittest.mock import MagicMock, patch +from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler from library.flux_train_utils import ( get_noisy_model_input_and_timesteps, ) @@ -218,3 +220,69 @@ def test_different_timestep_count(args, device): assert timesteps.shape == (2,) # Check that timesteps are within the proper range assert torch.all(timesteps < 500) + +# New tests for dynamic timestep shifting +def test_dynamic_timestep_shifting(device): + """Test the dynamic timestep shifting functionality""" + # Create a scheduler with dynamic shifting enabled + scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=1.0, + use_dynamic_shifting=True + ) + + # Test different image sizes + test_sizes = [ + (64, 64), # Small image + (256, 256), # Medium image + (512, 512), # Large image + (1024, 1024) # Very large image + ] + + for image_size in test_sizes: + # Simulate setting timesteps for inference + mu = math.log(1 + (image_size[0] * image_size[1]) / (256 * 256)) + scheduler.set_timesteps(num_inference_steps=50, mu=mu) + + # Check that sigmas have been dynamically shifted + assert len(scheduler.sigmas) == 51 # num_inference_steps + 1 + assert scheduler.sigmas[0] <= 1.0 # Maximum sigma should be <= 1 + assert scheduler.sigmas[-1] == 0.0 # Last sigma should always be 0 + +def test_sigma_generation_methods(): + """Test different sigma generation methods""" + # Test Karras sigmas + scheduler_karras = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + use_karras_sigmas=True + ) + scheduler_karras.set_timesteps(num_inference_steps=50) + assert len(scheduler_karras.sigmas) == 51 + + # Test Exponential sigmas + scheduler_exp = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + use_exponential_sigmas=True + ) + scheduler_exp.set_timesteps(num_inference_steps=50) + assert len(scheduler_exp.sigmas) == 51 + +def test_snr_calculation(): + """Test the SNR calculation method""" + scheduler = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=1.0 + ) + + # Prepare test timesteps + timesteps = torch.tensor([200, 600], dtype=torch.int32) + + # Test with different image sizes + test_sizes = [None, 64, (256, 256)] + + for image_size in test_sizes: + snr_values = scheduler.get_snr_for_timestep(timesteps, image_size) + + # Check basic properties + assert snr_values.shape == torch.Size([2]) + assert torch.all(snr_values >= 0) # SNR should always be non-negative