This commit is contained in:
Dave Lage
2025-09-28 00:23:44 +05:30
committed by GitHub
8 changed files with 782 additions and 60 deletions

View File

@@ -21,6 +21,13 @@ from library import (
strategy_flux, strategy_flux,
train_util, train_util,
) )
from library.custom_train_functions import (
prepare_scheduler_for_custom_training_flux,
apply_snr_weight,
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
)
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
@@ -299,8 +306,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
) )
def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
    """Build the flow-matching Euler scheduler used for FLUX network training.

    A deep copy of the freshly built scheduler is stored on
    ``self.noise_scheduler_copy`` before the SNR preparation helper runs, so
    the copy does not receive anything that helper attaches to the scheduler.

    Args:
        args: Parsed training arguments; ``discrete_flow_shift`` and
            ``timestep_sampling`` are read.
        device: Device handed to the SNR preparation helper.

    Returns:
        The scheduler instance that was prepared (not the stored copy).
    """
    # Dynamic (resolution-dependent) shifting is only enabled for "flux_shift".
    dynamic_shifting = args.timestep_sampling == "flux_shift"
    scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000,
        shift=args.discrete_flow_shift,
        use_dynamic_shifting=dynamic_shifting,
    )
    self.noise_scheduler_copy = copy.deepcopy(scheduler)
    prepare_scheduler_for_custom_training_flux(scheduler, device)
    return scheduler
def encode_images_to_latents(self, args, vae, images): def encode_images_to_latents(self, args, vae, images):
@@ -433,7 +441,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return model_pred, target, timesteps, weighting return model_pred, target, timesteps, weighting
def post_process_loss(self, loss: torch.Tensor, args, timesteps, noise_scheduler, latents: Optional[torch.Tensor] = None) -> torch.FloatTensor:
    """Apply the optional SNR-based loss re-weightings selected in ``args``.

    ``latents`` defaults to ``None`` so that callers still using the earlier
    four-argument form keep working; in that case no image size is forwarded
    to the weighting helpers.

    Args:
        loss: Loss tensor to re-weight.
        args: Training arguments; ``min_snr_gamma``,
            ``scale_v_pred_loss_like_noise_pred``, ``v_pred_like_loss``,
            ``debiased_estimation_loss`` and ``v_parameterization`` are read.
        timesteps: Timesteps matching ``loss``.
        noise_scheduler: Scheduler passed through to the weighting helpers.
        latents: Optional latents; only their last two dimensions are used,
            as the image size handed to the weighting helpers.

    Returns:
        The loss after every enabled weighting has been applied, in the order
        min-SNR, v-pred scaling, v-pred-like term, debiased estimation.
    """
    image_size = tuple(latents.shape[-2:]) if latents is not None else None

    if args.min_snr_gamma:
        loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization, image_size)
    if args.scale_v_pred_loss_like_noise_pred:
        loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler, image_size)
    if args.v_pred_like_loss:
        loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss, image_size)
    if args.debiased_estimation_loss:
        loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization, image_size)
    return loss
def get_sai_model_spec(self, args): def get_sai_model_spec(self, args):

View File

@@ -6,6 +6,7 @@ import re
from torch.types import Number from torch.types import Number
from typing import List, Optional, Union from typing import List, Optional, Union
from .utils import setup_logging from .utils import setup_logging
from library import train_util
setup_logging() setup_logging()
import logging import logging
@@ -17,7 +18,10 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
if hasattr(noise_scheduler, "all_snr"): if hasattr(noise_scheduler, "all_snr"):
return return
alphas_cumprod = noise_scheduler.alphas_cumprod if hasattr(noise_scheduler.config, "use_dynamic_shifting") and noise_scheduler.config.use_dynamic_shifting is True:
return
alphas_cumprod = train_util.get_alphas_cumprod(noise_scheduler)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod alpha = sqrt_alphas_cumprod
@@ -26,6 +30,22 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
noise_scheduler.all_snr = all_snr.to(device) noise_scheduler.all_snr = all_snr.to(device)
def prepare_scheduler_for_custom_training_flux(noise_scheduler, device):
    """Attach a precomputed ``all_snr`` table to a flow-matching scheduler.

    Nothing is done when the scheduler already has ``all_snr``, when its
    config has ``use_dynamic_shifting`` set to ``True``, or when
    ``train_util.get_alphas_cumprod`` returns ``None``.

    Args:
        noise_scheduler: Scheduler to annotate; mutated in place.
        device: Device the resulting table is moved to.
    """
    if hasattr(noise_scheduler, "all_snr"):
        return
    if getattr(noise_scheduler.config, "use_dynamic_shifting", None) is True:
        return

    alphas_cumprod = train_util.get_alphas_cumprod(noise_scheduler)
    if alphas_cumprod is None:
        return

    # SNR is taken as alphas_cumprod / (1 - alphas_cumprod).
    noise_part = 1.0 - alphas_cumprod
    noise_scheduler.all_snr = (alphas_cumprod / noise_part).to(device)
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler): def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
# fix beta: zero terminal SNR # fix beta: zero terminal SNR
@@ -65,8 +85,14 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
noise_scheduler.alphas_cumprod = alphas_cumprod noise_scheduler.alphas_cumprod = alphas_cumprod
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False): def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False, image_size=None):
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # Get the appropriate SNR values based on timesteps and potentially image size
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
snr = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
else:
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma)) min_snr_gamma = torch.minimum(snr, torch.full_like(snr, gamma))
if v_prediction: if v_prediction:
snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device) snr_weight = torch.div(min_snr_gamma, snr + 1).float().to(loss.device)
@@ -76,14 +102,19 @@ def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_sched
return loss return loss
def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None):
    """Multiply ``loss`` by the per-timestep factor from ``get_snr_scale``.

    Args:
        loss: Loss tensor to scale.
        timesteps: Timesteps matching ``loss``.
        noise_scheduler: Scheduler forwarded to ``get_snr_scale``.
        image_size: Optional image size forwarded to ``get_snr_scale``.

    Returns:
        The scaled loss.
    """
    return loss * get_snr_scale(timesteps, noise_scheduler, image_size)
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None):
# Get SNR values with image_size consideration
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
else:
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
scale = snr_t / (snr_t + 1) scale = snr_t / (snr_t + 1)
# # show debug info # # show debug info
@@ -91,24 +122,42 @@ def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler):
return scale return scale
def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_pred_like_loss: torch.Tensor, image_size=None):
    """Add a v-prediction-like term, ``loss / scale * v_pred_like_loss``, to ``loss``.

    ``scale`` is the per-timestep factor returned by ``get_snr_scale``.

    Args:
        loss: Loss tensor to augment.
        timesteps: Timesteps matching ``loss``.
        noise_scheduler: Scheduler forwarded to ``get_snr_scale``.
        v_pred_like_loss: Multiplier applied to the added term.
        image_size: Optional image size forwarded to ``get_snr_scale``.

    Returns:
        ``loss`` plus the added term.
    """
    snr_scale = get_snr_scale(timesteps, noise_scheduler, image_size)
    # logger.info(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
    extra_term = loss / snr_scale * v_pred_like_loss
    return loss + extra_term
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None):
    """Re-weight ``loss`` by a debiased-estimation factor derived from the SNR.

    The SNR comes from the scheduler's ``get_snr_for_timestep`` method when it
    exists and is callable; otherwise from the precomputed ``all_snr`` table.
    If the scheduler offers neither, ``loss`` is returned unchanged.

    The lookup probes the method with ``getattr`` so that schedulers exposing
    only ``all_snr`` reach the table fallback instead of raising
    ``AttributeError`` on the missing method.

    Args:
        loss: Loss tensor to re-weight.
        timesteps: Timesteps matching ``loss``.
        noise_scheduler: Scheduler providing the SNR values.
        v_prediction: Use ``1 / (snr + 1)`` when true, else ``1 / sqrt(snr)``.
        image_size: Optional image size forwarded to ``get_snr_for_timestep``.

    Returns:
        The weighted loss, or ``loss`` itself when no SNR source is available.
    """
    snr_lookup = getattr(noise_scheduler, "get_snr_for_timestep", None)
    if callable(snr_lookup):
        snr_t: torch.Tensor = snr_lookup(timesteps, image_size)
    elif hasattr(noise_scheduler, "all_snr"):
        timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
        snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
    else:
        # No SNR information at all: leave the loss untouched.
        return loss

    # Cap the SNR at 1000 to avoid numerical issues with very large values.
    snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)

    if v_prediction:
        weight = 1 / (snr_t + 1)
    else:
        weight = 1 / torch.sqrt(snr_t)
    return weight * loss
# TODO train_utilと分散しているのでどちらかに寄せる # TODO train_utilと分散しているのでどちらかに寄せる

View File

@@ -28,7 +28,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from library import sd3_models, sd3_utils, strategy_base, train_util from library import sd3_models, sd3_utils, strategy_base, train_util, flux_train_utils
def save_models( def save_models(
@@ -598,16 +598,29 @@ def sample_image_inference(
# region Diffusers # region Diffusers
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import BaseOutput from diffusers.utils import BaseOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin
@dataclass @dataclass
@@ -649,22 +662,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
self, self,
num_train_timesteps: int = 1000, num_train_timesteps: int = 1000,
shift: float = 1.0, shift: float = 1.0,
use_dynamic_shifting=False,
base_shift: Optional[float] = 0.5,
max_shift: Optional[float] = 1.15,
base_image_seq_len: Optional[int] = 256,
max_image_seq_len: Optional[int] = 4096,
invert_sigmas: bool = False,
shift_terminal: Optional[float] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
): ):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy() timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32) timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.timesteps = sigmas * num_train_timesteps self.timesteps = sigmas * num_train_timesteps
self._step_index = None self._step_index = None
self._begin_index = None self._begin_index = None
self._shift = shift
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item() self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item() self.sigma_max = self.sigmas[0].item()
@property
def shift(self):
    """Return the shift value currently stored in ``self._shift``."""
    return self._shift
@property @property
def step_index(self): def step_index(self):
""" """
@@ -690,6 +730,9 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
""" """
self._begin_index = begin_index self._begin_index = begin_index
def set_shift(self, shift: float):
    """Store ``shift`` as the value reported by the ``shift`` property.

    Only ``self._shift`` is written; no other scheduler state is touched here.
    """
    self._shift = shift
def scale_noise( def scale_noise(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
@@ -709,10 +752,31 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.FloatTensor`: `torch.FloatTensor`:
A scaled input sample. A scaled input sample.
""" """
if self.step_index is None: # Make sure sigmas and timesteps have the same device and dtype as original_samples
self._init_step_index(timestep) sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
if sample.device.type == "mps" and torch.is_floating_point(timestep):
# mps does not support float64
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
timestep = timestep.to(sample.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(sample.device)
timestep = timestep.to(sample.device)
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timestep.shape[0]
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timestep.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(sample.shape):
sigma = sigma.unsqueeze(-1)
sigma = self.sigmas[self.step_index]
sample = sigma * noise + (1.0 - sigma) * sample sample = sigma * noise + (1.0 - sigma) * sample
return sample return sample
@@ -720,7 +784,37 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
def _sigma_to_t(self, sigma): def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
    """Apply the exponential time shift ``e^mu / (e^mu + (1/t - 1)^sigma)``.

    Args:
        mu: Shift parameter; its exponential is used as the weight.
        sigma: Exponent applied to ``1/t - 1``.
        t: Values to shift.

    Returns:
        The shifted values, in the same container type arithmetic on ``t`` yields.
    """
    exp_mu = math.exp(mu)
    return exp_mu / (exp_mu + (1 / t - 1) ** sigma)
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
    r"""
    Rescale a timestep schedule so that its last entry equals `self.config.shift_terminal`.

    The distance of every entry from 1 is divided by a common factor chosen so that the last entry lands on the
    configured terminal value.

    Reference:
    https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

    Args:
        t (`torch.Tensor`):
            A tensor of timesteps to be stretched and shifted.

    Returns:
        `torch.Tensor`:
            A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
    """
    distance_from_one = 1 - t
    factor = distance_from_one[-1] / (1 - self.config.shift_terminal)
    return 1 - (distance_from_one / factor)
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
""" """
Sets the discrete timesteps used for the diffusion chain (to be run before inference). Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -730,18 +824,49 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, *optional*): device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
""" """
if self.config.use_dynamic_shifting and mu is None:
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
if sigmas is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
sigmas = timesteps / self.config.num_train_timesteps
else:
sigmas = np.array(sigmas).astype(np.float32)
num_inference_steps = len(sigmas)
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps) if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self._step_index = None self._step_index = None
self._begin_index = None self._begin_index = None
@@ -807,7 +932,11 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
returned, otherwise a tuple is returned where the first element is the sample tensor. returned, otherwise a tuple is returned where the first element is the sample tensor.
""" """
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor): if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError( raise ValueError(
( (
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -823,30 +952,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample.to(torch.float32) sample = sample.to(torch.float32)
sigma = self.sigmas[self.step_index] sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 prev_sample = sample + (sigma_next - sigma) * model_output
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
# backwards compatibility
# if self.config.prediction_type == "vector_field":
denoised = sample - model_output * sigma
# 2. Convert to an ODE derivative
derivative = (sample - denoised) / sigma_hat
dt = self.sigmas[self.step_index + 1] - sigma_hat
prev_sample = sample + derivative * dt
# Cast sample back to model compatible dtype # Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype) prev_sample = prev_sample.to(model_output.dtype)
@@ -858,9 +967,146 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
def _convert_to_beta(
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
) -> torch.Tensor:
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = np.array(
[
sigma_min + (ppf * (sigma_max - sigma_min))
for ppf in [
scipy.stats.beta.ppf(timestep, alpha, beta)
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
]
]
)
return sigmas
def __len__(self): def __len__(self):
return self.config.num_train_timesteps return self.config.num_train_timesteps
def get_snr_for_timestep(self, timesteps: torch.IntTensor, image_size=None):
    """
    Get the signal-to-noise ratio for given timesteps, with consideration for image size.

    A table of SNR values, ``(1 - sigma**2) / sigma**2``, is built from ``self.sigmas``. When dynamic shifting is
    enabled and ``image_size`` is given, the sigmas are first passed through ``self.time_shift`` with a ``mu``
    derived from the packed image area; that table is rebuilt on every call. Without dynamic shifting the table is
    cached on ``self.all_snr`` and reused. An ``all_snr`` attribute that already exists is used as-is.

    The table is moved to ``timesteps.device`` before it is indexed, so a table cached on one device still works
    when later timesteps live on another.

    Args:
        timesteps: Batch of timesteps (already scaled values, timesteps = sigma * 1000.0)
        image_size: Tuple of (height, width) or single int representing image dimensions

    Returns:
        torch.Tensor: SNR values corresponding to the timesteps
    """
    if hasattr(self, "all_snr"):
        all_snr = self.all_snr
    else:
        all_sigmas = self.sigmas
        assert isinstance(all_sigmas, torch.Tensor), "FlowMatch scheduler sigmas are not tensors"

        use_dynamic_shifting = self.config.use_dynamic_shifting
        if image_size is not None and use_dynamic_shifting:
            if isinstance(image_size, (tuple, list)):
                h, w = image_size
            else:
                h = w = image_size
            # Adjust for packed size
            h = h // 2
            w = w // 2
            mu = flux_train_utils.get_lin_function(y1=0.5, y2=1.15)(h * w)
            shifted_all_sigmas = self.time_shift(mu, 1.0, all_sigmas)
        else:
            # Static shifting was applied when the sigmas were built; with
            # dynamic shifting but no image size there is nothing to derive mu from.
            shifted_all_sigmas = all_sigmas

        all_snr = ((1.0 - shifted_all_sigmas**2) / (shifted_all_sigmas**2)).to(timesteps.device)

        # A resolution-dependent table cannot be reused, so only cache the static one.
        if not use_dynamic_shifting:
            self.all_snr = all_snr

    # A cached table may sit on a different device than this batch.
    all_snr = all_snr.to(timesteps.device)

    # Map timesteps in [0, 1000] linearly onto table positions.
    # NOTE(review): larger timesteps map to later table entries; if the sigma table is stored in
    # descending order this pairs high timesteps with low-noise entries. Confirm the intended direction.
    timestep_indices = (timesteps / 1000.0 * (len(all_snr) - 1)).long()
    return all_snr[timestep_indices]
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)

View File

@@ -31,6 +31,7 @@ from packaging.version import Version
import torch import torch
from library.device_utils import init_ipex, clean_memory_on_device from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
init_ipex() init_ipex()
@@ -60,7 +61,7 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler, KDPM2AncestralDiscreteScheduler,
AutoencoderKL, AutoencoderKL,
) )
from library import custom_train_functions, sd3_utils from library import custom_train_functions, sd3_utils, flux_train_utils
from library.original_unet import UNet2DConditionModel from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
import numpy as np import numpy as np
@@ -6145,7 +6146,7 @@ def get_noise_noisy_latents_and_timesteps(
return noise, noisy_latents, timesteps return noise, noisy_latents, timesteps
def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]: def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, latents: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"): if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
return None return None
@@ -6154,10 +6155,23 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
result = torch.exp(-alpha * timesteps) * args.huber_scale result = torch.exp(-alpha * timesteps) * args.huber_scale
elif args.huber_schedule == "snr": elif args.huber_schedule == "snr":
if not hasattr(noise_scheduler, "alphas_cumprod"): if hasattr(noise_scheduler, "sigmas"):
# Need to adjust the timesteps based on the latent dimensions
if args.timestep_sampling == "flux_shift":
_, _, h, w = latents.shape
mu = flux_train_utils.get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
alphas_cumprod = get_alphas_cumprod(noise_scheduler, mu)
else:
alphas_cumprod = get_alphas_cumprod(noise_scheduler)
else:
alphas_cumprod = get_alphas_cumprod(noise_scheduler)
if alphas_cumprod is None:
raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.") raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu()) timesteps_indices = index_for_timesteps(timesteps, noise_scheduler)
alphas_cumprod = torch.index_select(alphas_cumprod.to(timesteps.device), 0, timesteps_indices)
sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
result = result.to(timesteps.device) result = result.to(timesteps.device)
elif args.huber_schedule == "constant": elif args.huber_schedule == "constant":
@@ -6167,6 +6181,67 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
return result return result
def index_for_timesteps(timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
    """Translate timestep values into indices of the scheduler's schedule.

    When the scheduler has an ``index_for_timestep`` method, each timestep is
    resolved through it; if that call raises ``IndexError``, the index of the
    closest entry in ``noise_scheduler.timesteps`` is used instead. As a side
    effect ``noise_scheduler.timesteps`` is moved to the device of
    ``timesteps``. Schedulers without that method fall back to
    ``timesteps_to_indices`` sized by ``len(noise_scheduler.all_snr)``.

    Args:
        timesteps: Timestep values to resolve.
        noise_scheduler: Scheduler providing the schedule.

    Returns:
        A ``torch.long`` tensor of indices.
    """
    if not hasattr(noise_scheduler, "index_for_timestep"):
        return timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))

    device = timesteps.device
    noise_scheduler.timesteps = noise_scheduler.timesteps.to(device)

    def resolve(t):
        # Make sure the value handed to the scheduler is a tensor on the right device.
        t_tensor = t if isinstance(t, torch.Tensor) else torch.tensor([t], device=device)[0]
        try:
            return noise_scheduler.index_for_timestep(t_tensor)
        except IndexError:
            # No match reported by the scheduler: take the nearest schedule entry.
            return torch.abs(noise_scheduler.timesteps - t_tensor).argmin().item()

    resolved = [resolve(t) for t in timesteps]
    return torch.tensor(resolved, device=device, dtype=torch.long)
def timesteps_to_indices(timesteps: torch.Tensor, num_train_timesteps: int):
    """
    Convert the timesteps into indices by converting the timestep into an long integer.

    Accounts for timestep being within range 0 to 1 and 1 to 1000: when the largest value is at most 1.0 the batch
    is treated as normalized and scaled by ``num_train_timesteps - 1``; otherwise the values are treated as 1-based
    absolute timesteps and shifted down by one. Either way the result is rounded and clamped to
    ``[0, num_train_timesteps - 1]``, so out-of-range inputs cannot produce negative or overflowing indices.

    An empty batch yields an empty ``torch.long`` tensor.

    NOTE(review): a batch of absolute timesteps whose maximum is 1 is indistinguishable from a normalized batch and
    takes the normalized branch.
    """
    if timesteps.numel() == 0:
        # torch.max has no value for an empty tensor.
        return timesteps.to(torch.long)

    last_index = num_train_timesteps - 1
    if torch.max(timesteps) <= 1.0:
        # Timesteps are normalized, scale them to indices
        positions = timesteps * last_index
    else:
        # Timesteps start from 1 while indices start from 0
        positions = timesteps - 1
    return positions.round().to(torch.long).clamp(0, last_index)
def get_alphas_cumprod(noise_scheduler, mu=None) -> Optional[torch.Tensor]:
    """Return the scheduler's cumulative alpha products, deriving them if needed.

    Resolution order:

    1. ``noise_scheduler.alphas_cumprod`` is returned as-is when present.
    2. Otherwise, if the scheduler has ``sigmas``, the values are derived as
       ``1 / (1 + sigma**2)``. When ``noise_scheduler.config.use_dynamic_shifting``
       is exactly ``True`` the sigmas are first passed through
       ``noise_scheduler.time_shift(mu, 1.0, sigmas)``.
    3. Otherwise ``None`` is returned.

    Args:
        noise_scheduler: Scheduler object to inspect.
        mu: Shift parameter forwarded to ``time_shift``; only used when dynamic
            shifting is enabled.
            NOTE(review): ``mu=None`` is forwarded unchanged - confirm that
            ``time_shift`` accepts it, or that callers always pass ``mu`` when
            dynamic shifting is on.

    Returns:
        The alpha-cumprod tensor, or ``None`` when the scheduler exposes neither
        ``alphas_cumprod`` nor ``sigmas``.
    """
    if hasattr(noise_scheduler, "alphas_cumprod"):
        return noise_scheduler.alphas_cumprod

    if not hasattr(noise_scheduler, "sigmas"):
        return None

    # Read the flag defensively: a sigma-based scheduler may have no config, or a
    # config without this field, and should then be treated as not shifted.
    config = getattr(noise_scheduler, "config", None)
    if getattr(config, "use_dynamic_shifting", False) is True:
        sigmas = noise_scheduler.time_shift(mu, 1.0, noise_scheduler.sigmas)
    else:
        sigmas = noise_scheduler.sigmas

    # Uses the relation sigma^2 = (1 - alpha) / alpha, i.e. alpha = 1 / (1 + sigma^2).
    return 1.0 / (1.0 + sigmas**2)
def conditional_loss( def conditional_loss(
model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None

View File

@@ -392,7 +392,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
return model_pred, target, timesteps, weighting return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler): def post_process_loss(self, loss, args, timesteps, noise_scheduler, latents):
return loss return loss
def get_sai_model_spec(self, args): def get_sai_model_spec(self, args):

View File

@@ -0,0 +1,264 @@
import pytest
import torch
import numpy as np
from unittest.mock import MagicMock, patch
# Import the functions we're testing
from library.custom_train_functions import (
apply_snr_weight,
scale_v_prediction_loss_like_noise_prediction,
get_snr_scale,
add_v_prediction_like_loss,
apply_debiased_estimation,
)
@pytest.fixture
def loss():
    """Loss fixture: a (2, 1) tensor filled with ones."""
    shape = (2, 1)
    return torch.ones(shape)
@pytest.fixture
def timesteps():
    """Timestep fixture: a (1, 2) int32 tensor holding 200 and 600."""
    values = [[200, 600]]
    return torch.tensor(values, dtype=torch.int32)
@pytest.fixture
def noise_scheduler():
    """Mock scheduler exposing both ``get_snr_for_timestep`` and ``all_snr``."""
    mock_scheduler = MagicMock()
    # Fixed SNR values returned for any lookup through the scheduler method.
    snr_lookup = MagicMock(return_value=torch.tensor([0.294, 0.39]))
    mock_scheduler.get_snr_for_timestep = snr_lookup
    # Table consulted by code paths that index SNR values directly.
    mock_scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
    return mock_scheduler
# Tests for apply_snr_weight
def test_apply_snr_weight_with_get_snr_method(loss, timesteps, noise_scheduler):
    """apply_snr_weight leaves a unit loss at 1.0 for gamma=5 with the mocked SNR method."""
    weighted = apply_snr_weight(loss, timesteps, noise_scheduler, 5.0, v_prediction=False, image_size=64)

    expected = torch.tensor([[1.0, 1.0]])
    assert torch.allclose(weighted, expected, rtol=1e-4, atol=1e-4)
def test_apply_snr_weight_with_all_snr(loss, timesteps):
    """apply_snr_weight with a scheduler whose ``get_snr_for_timestep`` is None."""
    # Scheduler that only offers the ``all_snr`` table; the method slot is None.
    table_only_scheduler = MagicMock()
    table_only_scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
    table_only_scheduler.get_snr_for_timestep = None

    weighted = apply_snr_weight(loss, timesteps, table_only_scheduler, 5.0, v_prediction=False)

    assert torch.allclose(weighted, torch.tensor([1.0, 1.0]), rtol=1e-4, atol=1e-4)
def test_apply_snr_weight_with_v_prediction(loss, timesteps, noise_scheduler):
    """apply_snr_weight with ``v_prediction=True`` against pinned expected weights."""
    weighted = apply_snr_weight(loss, timesteps, noise_scheduler, 5.0, v_prediction=True)

    expected_row = [0.2272, 0.2806]
    expected = torch.tensor([expected_row, expected_row])
    assert torch.allclose(weighted, expected, rtol=1e-4, atol=1e-4)
# Tests for scale_v_prediction_loss_like_noise_prediction
def test_scale_v_prediction_loss(loss, timesteps, noise_scheduler):
    """scale_v_prediction_loss_like_noise_prediction multiplies the loss by get_snr_scale."""
    with patch("library.custom_train_functions.get_snr_scale") as snr_scale_mock:
        snr_scale_mock.return_value = torch.tensor([0.9, 0.8])

        scaled = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)

        # The helper must be consulted once, with no image size supplied.
        snr_scale_mock.assert_called_once_with(timesteps, noise_scheduler, None)

        # The (2, 1) loss broadcasts against the scale to give a (2, 2) result.
        broadcast_scale = torch.tensor([[0.9, 0.8], [0.9, 0.8]])
        assert torch.allclose(scaled, loss * broadcast_scale)
# Tests for get_snr_scale
def test_get_snr_scale_with_get_snr_method(timesteps, noise_scheduler):
    """get_snr_scale forwards timesteps and image size to the scheduler's SNR method."""
    size = 64
    scale = get_snr_scale(timesteps, noise_scheduler, size)

    # The scheduler method must be called exactly once with both arguments.
    noise_scheduler.get_snr_for_timestep.assert_called_once_with(timesteps, size)

    assert torch.allclose(scale, torch.tensor([0.2272, 0.2806]), rtol=1e-4, atol=1e-4)
def test_get_snr_scale_with_all_snr(timesteps):
    """get_snr_scale with a scheduler whose ``get_snr_for_timestep`` is None."""
    # Only the ``all_snr`` table is usable on this mock.
    table_only_scheduler = MagicMock()
    table_only_scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 0.75, 1.0])
    table_only_scheduler.get_snr_for_timestep = None

    scale = get_snr_scale(timesteps, table_only_scheduler)

    assert torch.allclose(scale, torch.tensor([[0.5000, 0.5000]]), rtol=1e-4, atol=1e-4)
def test_get_snr_scale_with_large_snr(timesteps, noise_scheduler):
    """get_snr_scale when the mocked scheduler reports a very large SNR value."""
    huge_and_small_snr = torch.tensor([2000.0, 5.0])
    noise_scheduler.get_snr_for_timestep.return_value = huge_and_small_snr

    scale = get_snr_scale(timesteps, noise_scheduler)

    assert torch.allclose(scale, torch.tensor([0.9990, 0.8333]), rtol=1e-4, atol=1e-4)
# Tests for add_v_prediction_like_loss
def test_add_v_prediction_like_loss(loss, timesteps, noise_scheduler):
    """add_v_prediction_like_loss against pinned values, with get_snr_scale patched."""
    extra_loss = torch.tensor([0.3, 0.2]).view(2, 1)

    with patch("library.custom_train_functions.get_snr_scale") as snr_scale_mock:
        snr_scale_mock.return_value = torch.tensor([0.9, 0.8])

        combined = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, extra_loss)

        # The helper must be consulted once, with no image size supplied.
        snr_scale_mock.assert_called_once_with(timesteps, noise_scheduler, None)

        expected = torch.tensor([[1.3333, 1.3750], [1.2222, 1.2500]])
        assert torch.allclose(combined, expected, rtol=1e-4, atol=1e-4)
# Tests for apply_debiased_estimation
def test_apply_debiased_estimation_no_snr(loss, timesteps):
    """apply_debiased_estimation returns the loss unchanged for this SNR-less mock."""
    # The attribute is set to None explicitly; deleting it is not used here.
    bare_scheduler = MagicMock()
    bare_scheduler.get_snr_for_timestep = None

    debiased = apply_debiased_estimation(loss, timesteps, bare_scheduler)

    # Expected: the input loss comes back exactly as given.
    assert torch.equal(debiased, loss)
def test_apply_debiased_estimation_with_get_snr_method(loss, timesteps, noise_scheduler):
    """apply_debiased_estimation against pinned values for both prediction modes."""
    # Noise-prediction mode.
    without_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)
    expected_without_v = torch.tensor([[1.8443, 1.6013], [1.8443, 1.6013]])
    assert torch.allclose(without_v, expected_without_v, rtol=1e-4, atol=1e-4)

    # v-prediction mode.
    with_v = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=True)
    expected_with_v = torch.tensor([[0.7728, 0.7194], [0.7728, 0.7194]])
    assert torch.allclose(with_v, expected_with_v, rtol=1e-4, atol=1e-4)
def test_apply_debiased_estimation_with_all_snr(loss, timesteps):
    """apply_debiased_estimation with a scheduler whose ``get_snr_for_timestep`` is None."""
    # Only the ``all_snr`` table is usable on this mock.
    table_only_scheduler = MagicMock()
    table_only_scheduler.all_snr = torch.tensor([0.05, 0.1, 0.15, 0.2, 0.25, 0.5, 1.0])
    table_only_scheduler.get_snr_for_timestep = None

    debiased = apply_debiased_estimation(loss, timesteps, table_only_scheduler, v_prediction=False)

    assert torch.allclose(debiased, torch.tensor([[1.0, 1.0]]), rtol=1e-4, atol=1e-4)
def test_apply_debiased_estimation_with_large_snr(loss, timesteps, noise_scheduler):
    """apply_debiased_estimation when the mocked scheduler reports a very large SNR."""
    huge_and_small_snr = torch.tensor([2000.0, 5.0])
    noise_scheduler.get_snr_for_timestep.return_value = huge_and_small_snr

    debiased = apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False)

    expected_row = [0.0316, 0.4472]
    expected = torch.tensor([expected_row, expected_row])
    assert torch.allclose(debiased, expected, rtol=1e-4, atol=1e-4)
# Additional edge cases
def test_empty_tensors(noise_scheduler):
    """apply_snr_weight keeps shape and dtype when every input tensor is empty."""
    empty_loss = torch.tensor([], dtype=torch.float32)
    empty_timesteps = torch.tensor([], dtype=torch.int32)
    assert isinstance(empty_timesteps, torch.IntTensor)

    # The mocked SNR lookup must return an empty tensor as well.
    noise_scheduler.get_snr_for_timestep.return_value = torch.tensor([], dtype=torch.float32)

    weighted = apply_snr_weight(empty_loss, empty_timesteps, noise_scheduler, gamma=5.0)

    assert weighted.shape == empty_loss.shape
    assert weighted.dtype == empty_loss.dtype
def test_different_device_compatibility(loss, timesteps, noise_scheduler):
    """apply_snr_weight keeps the result on the loss tensor's device.

    Only the CPU device is exercised here. The SNR tensor returned by the mocked
    scheduler is created on that same device.
    """
    gamma = 5.0
    device = torch.device("cpu")

    # Place the loss explicitly on the device under test.
    loss_on_device = loss.to(device)

    # SNR values the mocked scheduler hands back, already on the target device.
    snr_tensor = torch.tensor([0.204, 0.294], device=device)
    noise_scheduler.get_snr_for_timestep.return_value = snr_tensor

    result = apply_snr_weight(loss_on_device, timesteps, noise_scheduler, gamma)

    # Without assertions this test could never fail on wrong output.
    assert result.device == loss_on_device.device
    assert result.dtype == loss_on_device.dtype
    # Allow for broadcasting in the trailing dimension.
    assert result.shape[0] == loss_on_device.shape[0]
# Additional tests for new functionality
def test_apply_snr_weight_with_image_size(loss, timesteps, noise_scheduler):
    """apply_snr_weight accepts None, an int, or a tuple as ``image_size``."""
    for size in (None, 64, (256, 256)):
        weighted = apply_snr_weight(loss, timesteps, noise_scheduler, 5.0, v_prediction=False, image_size=size)

        # Only the leading dimension is pinned, to allow for broadcasting.
        assert weighted.shape[0] == loss.shape[0]
        assert weighted.dtype == loss.dtype
def test_apply_debiased_estimation_variations(loss, timesteps, noise_scheduler):
    """apply_debiased_estimation across image-size forms and both prediction modes."""
    for size in (None, 64, (256, 256)):
        for use_v_prediction in (True, False):
            debiased = apply_debiased_estimation(
                loss,
                timesteps,
                noise_scheduler,
                v_prediction=use_v_prediction,
                image_size=size,
            )

            # Only the leading dimension is pinned, to allow for broadcasting.
            assert debiased.shape[0] == loss.shape[0]
            assert debiased.dtype == loss.dtype

View File

@@ -1,6 +1,8 @@
import pytest import pytest
import torch import torch
import math
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
from library.flux_train_utils import ( from library.flux_train_utils import (
get_noisy_model_input_and_timesteps, get_noisy_model_input_and_timesteps,
) )
@@ -218,3 +220,69 @@ def test_different_timestep_count(args, device):
assert timesteps.shape == (2,) assert timesteps.shape == (2,)
# Check that timesteps are within the proper range # Check that timesteps are within the proper range
assert torch.all(timesteps < 500) assert torch.all(timesteps < 500)
# New tests for dynamic timestep shifting
def test_dynamic_timestep_shifting(device):
    """set_timesteps with dynamic shifting yields a well-formed sigma schedule."""
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=True)

    # Small, medium, large and very large images.
    resolutions = ((64, 64), (256, 256), (512, 512), (1024, 1024))

    for height, width in resolutions:
        # mu grows with the pixel count relative to a 256x256 image.
        mu = math.log(1 + (height * width) / (256 * 256))
        scheduler.set_timesteps(num_inference_steps=50, mu=mu)

        assert len(scheduler.sigmas) == 51  # num_inference_steps + 1
        assert scheduler.sigmas[0] <= 1.0  # largest sigma must not exceed 1
        assert scheduler.sigmas[-1] == 0.0  # schedule always ends at zero
def test_sigma_generation_methods():
    """Karras and exponential sigma options both produce 51 sigmas for 50 steps."""
    # Karras sigmas.
    karras = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, use_karras_sigmas=True)
    karras.set_timesteps(num_inference_steps=50)
    assert len(karras.sigmas) == 51

    # Exponential sigmas.
    exponential = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, use_exponential_sigmas=True)
    exponential.set_timesteps(num_inference_steps=50)
    assert len(exponential.sigmas) == 51
def test_snr_calculation():
    """get_snr_for_timestep returns two non-negative values for each image-size form."""
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=1.0)
    query_timesteps = torch.tensor([200, 600], dtype=torch.int32)

    for size in (None, 64, (256, 256)):
        snr = scheduler.get_snr_for_timestep(query_timesteps, size)

        assert snr.shape == torch.Size([2])
        assert torch.all(snr >= 0)  # SNR should never be negative

View File

@@ -327,7 +327,7 @@ class NetworkTrainer:
return noise_pred, target, timesteps, None return noise_pred, target, timesteps, None
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler, latents: Optional[torch.Tensor]) -> torch.FloatTensor:
if args.min_snr_gamma: if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred: if args.scale_v_pred_loss_like_noise_pred:
@@ -464,7 +464,7 @@ class NetworkTrainer:
is_train=is_train, is_train=is_train,
) )
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None: if weighting is not None:
loss = loss * weighting loss = loss * weighting
@@ -475,7 +475,7 @@ class NetworkTrainer:
loss_weights = batch["loss_weights"] # 各sampleごとのweight loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights loss = loss * loss_weights
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) loss = self.post_process_loss(loss, args, timesteps, noise_scheduler, latents)
return loss.mean() return loss.mean()