Dave Lage
2026-02-20 09:30:55 -08:00
committed by GitHub
8 changed files with 782 additions and 60 deletions

@@ -31,6 +31,7 @@ from packaging.version import Version
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
+from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
 from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy

 init_ipex()
@@ -60,7 +61,7 @@ from diffusers import (
     KDPM2AncestralDiscreteScheduler,
     AutoencoderKL,
 )
-from library import custom_train_functions, sd3_utils
+from library import custom_train_functions, sd3_utils, flux_train_utils
 from library.original_unet import UNet2DConditionModel
 from huggingface_hub import hf_hub_download
 import numpy as np
@@ -6107,7 +6108,7 @@ def get_noise_noisy_latents_and_timesteps(
     return noise, noisy_latents, timesteps


-def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
+def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, latents: torch.Tensor, noise_scheduler) -> Optional[torch.Tensor]:
     if not (args.loss_type == "huber" or args.loss_type == "smooth_l1"):
         return None
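With the new `latents` parameter, every call site has to pass the current latent batch as well. A representative call (hypothetical variable names, sketched from the signatures in this diff) might look like:

    # huber_c is None unless args.loss_type is "huber" or "smooth_l1"
    huber_c = get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
    loss = conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)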
@@ -6116,10 +6117,23 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
         alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps
         result = torch.exp(-alpha * timesteps) * args.huber_scale
     elif args.huber_schedule == "snr":
-        if not hasattr(noise_scheduler, "alphas_cumprod"):
+        if hasattr(noise_scheduler, "sigmas"):
+            # Flow-matching schedulers carry sigmas instead of alphas_cumprod; with
+            # flux_shift sampling the shift mu also depends on the latent dimensions
+            if args.timestep_sampling == "flux_shift":
+                _, _, h, w = latents.shape
+                mu = flux_train_utils.get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
+                alphas_cumprod = get_alphas_cumprod(noise_scheduler, mu)
+            else:
+                alphas_cumprod = get_alphas_cumprod(noise_scheduler)
+        else:
+            alphas_cumprod = get_alphas_cumprod(noise_scheduler)
+        if alphas_cumprod is None:
             raise NotImplementedError("Huber schedule 'snr' is not supported with the current model.")
-        alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps.cpu())
+        timesteps_indices = index_for_timesteps(timesteps, noise_scheduler)
+        alphas_cumprod = torch.index_select(alphas_cumprod.to(timesteps.device), 0, timesteps_indices)
         sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
         result = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c
         result = result.to(timesteps.device)
     elif args.huber_schedule == "constant":
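For intuition, the "snr" schedule interpolates the Huber threshold between 1.0 at low noise and `args.huber_c` at high noise. A minimal standalone sketch of the same arithmetic, with assumed values:

    import torch

    huber_c = 0.1
    sigmas = torch.tensor([0.05, 1.0, 10.0])  # low, medium, high noise levels
    result = (1 - huber_c) / (1 + sigmas) ** 2 + huber_c
    # tensor([0.9163, 0.3250, 0.1074]) -- decays toward huber_c as noise grows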
@@ -6129,6 +6143,67 @@ def get_huber_threshold_if_needed(args, timesteps: torch.Tensor, noise_scheduler
     return result


+def index_for_timesteps(timesteps: torch.Tensor, noise_scheduler) -> torch.Tensor:
+    if hasattr(noise_scheduler, "index_for_timestep"):
+        noise_scheduler.timesteps = noise_scheduler.timesteps.to(timesteps.device)
+        # Convert timesteps to appropriate indices using the scheduler's method
+        indices = []
+        for t in timesteps:
+            # Make sure t is a tensor with the right device
+            t_tensor = t if isinstance(t, torch.Tensor) else torch.tensor([t], device=timesteps.device)[0]
+            try:
+                # Use the scheduler's method to get the correct index
+                idx = noise_scheduler.index_for_timestep(t_tensor)
+                indices.append(idx)
+            except IndexError:
+                # Handle case where no exact match is found
+                schedule_timesteps = noise_scheduler.timesteps
+                closest_idx = torch.abs(schedule_timesteps - t_tensor).argmin().item()
+                indices.append(closest_idx)
+        timesteps_indices = torch.tensor(indices, device=timesteps.device, dtype=torch.long)
+    else:
+        timesteps_indices = timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
+    return timesteps_indices
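The `IndexError` fallback is a nearest-neighbor lookup over the schedule. In isolation, with an assumed (not the scheduler's actual) timestep table:

    import torch

    schedule_timesteps = torch.tensor([999.0, 748.0, 499.0, 250.0, 0.0])  # assumed table
    t_tensor = torch.tensor(500.0)  # no exact match in the table
    closest_idx = torch.abs(schedule_timesteps - t_tensor).argmin().item()  # -> 2 (499 is nearest)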
+def timesteps_to_indices(timesteps: torch.Tensor, num_train_timesteps: int):
+    """
+    Convert the timesteps into indices by casting each timestep to a long integer.
+    Handles both normalized timesteps (0 to 1) and absolute timesteps (1 to num_train_timesteps).
+    """
+    # Check whether timesteps are normalized (0-1) or absolute (1-num_train_timesteps)
+    if torch.max(timesteps) <= 1.0:
+        # Timesteps are normalized; scale them up to the index range
+        timesteps_indices = (timesteps * (num_train_timesteps - 1)).round().to(torch.long)
+    else:
+        # Timesteps are already in the range 1 to num_train_timesteps; shift down by
+        # one because timesteps start from 1 but indices from 0
+        timesteps_indices = (timesteps - 1).round().to(torch.long).clamp(0, num_train_timesteps - 1)
+    return timesteps_indices
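A quick check of both branches, assuming num_train_timesteps = 1000:

    import torch

    # normalized timesteps (max <= 1.0) are scaled up to the index range
    timesteps_to_indices(torch.tensor([0.0, 0.25, 1.0]), 1000)      # -> tensor([  0, 250, 999])
    # absolute timesteps (1..1000) are shifted down by one
    timesteps_to_indices(torch.tensor([1.0, 500.0, 1000.0]), 1000)  # -> tensor([  0, 499, 999])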
+def get_alphas_cumprod(noise_scheduler, mu=None) -> Optional[torch.Tensor]:
+    """
+    Get the cumulative product of the alpha values across the timesteps.
+    Uses the scheduler's alphas_cumprod when available, otherwise derives it from sigmas.
+    """
+    if hasattr(noise_scheduler, "alphas_cumprod"):
+        alphas_cumprod = noise_scheduler.alphas_cumprod
+    elif hasattr(noise_scheduler, "sigmas"):
+        if noise_scheduler.config.use_dynamic_shifting:
+            sigmas = noise_scheduler.time_shift(mu, 1.0, noise_scheduler.sigmas)
+        else:
+            sigmas = noise_scheduler.sigmas
+        # alphas_cumprod is not available directly, but in many diffusion models
+        # sigma^2 = (1 - a) / a where a is the cumulative product of alphas,
+        # so it can be derived from sigmas
+        alphas_cumprod = 1.0 / (1.0 + sigmas**2)
+    else:
+        return None
+    return alphas_cumprod
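The sigma-to-alpha conversion is the exact inverse of the sigma computation in get_huber_threshold_if_needed: sigma^2 = (1 - a) / a implies a = 1 / (1 + sigma^2). A round-trip check with assumed values:

    import torch

    alphas_cumprod = torch.tensor([0.9, 0.5, 0.1])
    sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5  # alpha -> sigma
    recovered = 1.0 / (1.0 + sigmas**2)                        # sigma -> alpha
    assert torch.allclose(recovered, alphas_cumprod)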


 def conditional_loss(
     model_pred: torch.Tensor, target: torch.Tensor, loss_type: str, reduction: str, huber_c: Optional[torch.Tensor] = None