Add DDO loss

This commit is contained in:
rockerBOO
2025-04-30 03:34:19 -04:00
parent 8e8243a423
commit 9a2101a040
4 changed files with 160 additions and 248 deletions

View File

@@ -384,6 +384,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
model_pred = unet(
img=img,
img_ids=img_ids,

View File

@@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
with torch.no_grad():
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
stream.synchronize()
stream.synchronize()
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
# cpu to cuda
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view
stream.synchronize()
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

View File

@@ -505,41 +505,23 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
loss = loss * mask_image
return loss
def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor], apply_loss: Callable[[torch.Tensor], torch.Tensor], beta_dpo: float):
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
"""
DPO loss
Diffusion DPO loss
Args:
loss: pairs of w, l losses B//2, C, H, W
call_unet: function to call unet
apply_loss: function to apply loss
loss: pairs of w, l losses B//2
ref_loss: ref pairs of w, l losses B//2
beta_dpo: beta_dpo weight
Returns:
tuple:
- loss: mean loss of C, H, W
- metrics:
- total_loss: mean loss of C, H, W
- raw_model_loss: mean loss of C, H, W
- ref_loss: mean loss of C, H, W
- implicit_acc: accumulated implicit of C, H, W
"""
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
model_loss_w, model_loss_l = model_loss.chunk(2)
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
model_diff = model_loss_w - model_loss_l
# ref loss
with torch.no_grad():
ref_noise_pred = call_unet()
ref_loss = apply_loss(ref_noise_pred)
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean()
loss_w, loss_l = loss.chunk(2)
raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1))
model_diff = loss_w - loss_l
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean(dim=1)
scale_term = -0.5 * beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
@@ -549,10 +531,10 @@ def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor],
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
metrics = {
"total_loss": model_loss.detach().mean().item(),
"raw_model_loss": raw_model_loss.detach().mean().item(),
"ref_loss": raw_ref_loss.detach().item(),
"implicit_acc": implicit_acc.detach().item(),
"loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
"loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
"loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(),
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(),
}
return loss, metrics
@@ -563,28 +545,15 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
Args:
loss: pairs of w, l losses B//2, C, H, W
mapo_loss: mapo weight
mapo_weight: mapo weight
num_train_timesteps: number of timesteps
Returns:
tuple:
- loss: mean loss of C, H, W
- metrics:
- total_loss: mean loss of C, H, W
- ratio_loss: mean ratio loss of C, H, W
- model_losses_w: mean loss of w losses of C, H, W
- model_losses_l: mean loss of l losses of C, H, W
- win_score : mean win score of C, H, W
- lose_score : mean lose score of C, H, W
"""
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
snr = 0.5
model_losses_w, model_losses_l = model_loss.chunk(2)
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
snr * model_losses_l
) / (torch.exp(snr * model_losses_l) - 1)
loss_w, loss_l = loss.chunk(2)
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (
snr * loss_l
) / (torch.exp(snr * loss_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
@@ -592,141 +561,91 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
ratio_losses = mapo_weight * ratio
# Full MaPO loss
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
metrics = {
"total_loss": loss.detach().mean().item(),
"ratio_loss": -ratio_losses.detach().mean().item(),
"model_losses_w": model_losses_w.detach().mean().item(),
"model_losses_l": model_losses_l.detach().mean().item(),
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)).detach().mean().item(),
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)).detach().mean().item(),
"model_losses_w": loss_w.detach().mean().item(),
"model_losses_l": loss_l.detach().mean().item(),
"win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
"lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
}
return loss, metrics
class FlowMatchingDDOLoss(nn.Module):
    """DDO-style discriminative loss for flow matching models.

    The per-sample squared error of the trained model and of a reference
    model against the same target vector field is used as a proxy for a
    log-likelihood ratio. The first half of the batch is treated as "real"
    samples and the second half as "fake" samples.

    Attributes are set in ``__init__``:
        alpha: weight applied to the fake-sample term.
        beta: scale applied to the error difference before the sigmoid.
    """

    def __init__(self, alpha=4.0, beta=0.05):
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(
        self, v_theta: Tensor, v_theta_ref: Tensor, v_target: Tensor, time=None
    ):
        """
        Compute DDO loss for flow matching models

        Args:
            v_theta: Vector field predicted by target model, batch first
            v_theta_ref: Vector field predicted by reference model
            v_target: Target vector field (e.g., straight-line for rectified flow)
            time: Time parameter t (accepted for interface compatibility; not used)

        Returns:
            DDO loss value (scalar tensor)

        Raises:
            ValueError: if the batch holds fewer than two samples, because the
                real half would be empty and its mean would be NaN.
        """
        batch_size = v_theta.shape[0]
        if batch_size < 2:
            raise ValueError(
                f"FlowMatchingDDOLoss needs a batch of at least 2 samples, got {batch_size}"
            )

        # For flow matching, error is based on vector field difference.
        # Sum over every non-batch dimension so inputs of any rank work,
        # not only 4-D (B, C, H, W) tensors.
        sq_err_theta = (v_theta - v_target) ** 2
        sq_err_ref = (v_theta_ref - v_target) ** 2
        error_theta = sq_err_theta.reshape(sq_err_theta.shape[0], -1).sum(dim=1)
        error_theta_ref = sq_err_ref.reshape(sq_err_ref.shape[0], -1).sum(dim=1)

        # Likelihood ratio approximation
        delta = error_theta_ref - error_theta
        scaled_delta = self.beta * delta

        # Split batch into real and fake parts; with an odd batch size the
        # extra sample lands in the fake half.
        half_batch = batch_size // 2
        real_delta = scaled_delta[:half_batch]
        fake_delta = scaled_delta[half_batch:]

        real_loss = -F.logsigmoid(real_delta).mean()
        fake_loss = -F.logsigmoid(-fake_delta).mean()
        loss = real_loss + self.alpha * fake_loss
        return loss
def compute_target_velocity(x_t: Tensor, t: Tensor):
def ddo_loss(
loss: Tensor,
ref_loss: Tensor,
ddo_alpha: float=4.0,
ddo_beta: float=0.05,
weighting: Tensor | None=None
):
"""
Compute the target velocity vector field for flow matching.
For rectified flow, the target velocity is the straight-line path derivative.
Calculate DDO loss for flow matching diffusion models.
This implementation follows the paper's approach:
1. Use prediction errors as proxy for log likelihood ratio
2. Apply sigmoid to create a discriminator from this ratio
3. Optimize using the standard GAN discriminator loss
Args:
x_t: Points along the path at time t (batch_size, channels, height, width)
t: Time values in [0,1] (batch_size,)
loss: loss B, N
ref_loss: ref loss B, N
ddo_alpha: Weight for the fake sample term
ddo_beta: Scaling factor for the likelihood ratio
weighting: Optional time-dependent weighting
Returns:
Target velocity vectors v(x_t, t) for flow matching
The DDO loss value
"""
batch_size = x_t.shape[0]
# Get corresponding data and noise endpoints
with torch.no_grad():
# For each interpolated point, we need the endpoints of its path
# In practice, these might come from a cache or be passed as arguments
x1 = get_data_endpoints(x_t, t) # Real data endpoint (t=0)
x0 = get_noise_endpoints(x_t, t) # Noise endpoint (t=1)
# Reshape t for broadcasting
t = t.view(batch_size, 1, 1, 1)
# For standard rectified flow, the target velocity is constant along the path:
# v(x_t, t) = x1 - x0
v_target = x1 - x0
# For time-dependent velocity fields (non-rectified), we would scale by time:
# v_target = v_target * g(t) # where g(t) is a time-dependent scaling function
return v_target
def get_data_endpoints(x_t: Tensor, t: Tensor):
    """
    Get the data endpoints (t=0) for the given points on the path.

    Solves the straight-line path x_t = (1-t)*x1 + t*x0 for x1.

    NOTE(review): x0 is freshly sampled Gaussian noise on every call, so the
    result is stochastic for t != 0 and is not tied to the noise that
    originally produced x_t -- confirm this is intended.

    Args:
        x_t: Points on the path at time t, batch first (any rank)
        t: Time values, one per batch element

    Returns:
        The data endpoints (x at t=0), same shape as x_t
    """
    # One singleton axis per non-batch dimension so t broadcasts per sample
    # for any input rank, not only 4-D tensors.
    t = t.view(-1, *([1] * (x_t.ndim - 1)))
    x0 = torch.randn_like(x_t)  # Noise endpoint

    # Solve for x1: x1 = (x_t - t*x0) / (1-t)
    # Clamp the denominator to a small epsilon to prevent division by zero at t=1
    epsilon = 1e-8
    x1 = (x_t - t * x0) / (torch.clamp(1 - t, min=epsilon))
    return x1
def get_noise_endpoints(x_t: Tensor, t: Tensor):
    """
    Return noise endpoints (x at t=1) shaped like the given path points.

    The endpoints are drawn as standard Gaussian noise with the same shape,
    dtype and device as x_t.

    Args:
        x_t: Points on the path at time t
        t: Time values (not used by this function)

    Returns:
        A freshly sampled noise tensor matching x_t
    """
    return torch.randn_like(x_t)
# Calculate per-sample MSE between predictions and target
# Flatten spatial and channel dimensions, keeping batch dimension
# target_error = ((noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
# ref_error = ((ref_noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
# Apply weighting if provided (e.g., for time-dependent importance)
if weighting is not None:
if isinstance(weighting, tuple):
# Use first element if it's a tuple
weighting = weighting[0]
if weighting.ndim > 1:
# Ensure weighting is the right shape
weighting = weighting.view(-1)
loss = loss * weighting
ref_loss = ref_loss * weighting
# Calculate the log likelihood ratio
# For flow matching, lower error = higher likelihood
# So the log ratio is proportional to negative of error difference
log_ratio = ddo_beta * (ref_loss - loss)
# No mid-point real/fake split of the batch is performed here.
# In this implementation, the entire batch is treated as real samples
# and each sample is compared against its own reference prediction
# This approach works because the reference model (with LoRA disabled)
# produces predictions that serve as the "fake" distribution
# Loss for real samples: maximize log σ(ratio)
real_loss_terms = -torch.nn.functional.logsigmoid(log_ratio)
real_loss = real_loss_terms.mean()
# Loss for fake samples: maximize log(1-σ(ratio))
# Since we're using the same batch for both real and fake,
# we interpret this as maximizing log(1-σ(ratio)) for the samples when viewed from reference
fake_loss_terms = -torch.nn.functional.logsigmoid(-log_ratio)
fake_loss = ddo_alpha * fake_loss_terms.mean()
total_loss = real_loss + fake_loss
metrics = {
"loss/ddo_real": real_loss.detach().item(),
"loss/ddo_fake": fake_loss.detach().item(),
"loss/ddo_total": total_loss.detach().item(),
"ddo_log_ratio_mean": log_ratio.detach().mean().item(),
}
return total_loss, metrics
"""

View File

@@ -37,6 +37,7 @@ import library.huggingface_util as huggingface_util
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import (
apply_snr_weight,
ddo_loss,
get_weighted_text_embeddings,
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
@@ -45,8 +46,7 @@ from library.custom_train_functions import (
apply_masked_loss,
diffusion_dpo_loss,
mapo_loss,
FlowMatchingDDOLoss,
compute_target_velocity,
calculate_ddo_loss_for_dit_flow_matching,
)
from library.utils import setup_logging, add_logging_arguments
@@ -270,7 +270,7 @@ class NetworkTrainer:
weight_dtype: torch.dtype,
train_unet: bool,
is_train=True,
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
@@ -324,8 +324,8 @@ class NetworkTrainer:
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
return noise_pred, noisy_latents, target, timesteps, None
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
return noise_pred, noisy_latents, target, sigmas, timesteps, None
def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
if args.min_snr_gamma:
@@ -452,7 +452,8 @@ class NetworkTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, noisy_latents, target, timesteps, weighting = self.get_noise_pred_and_target(
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -466,17 +467,52 @@ class NetworkTrainer:
is_train=is_train,
)
if args.ddo_beta is not None or args.ddo_alpha is not None:
# Compute DDO loss
ddo_loss = FlowMatchingDDOLoss(alpha=args.ddo_beta or 4.0, beta=args.ddo_alpha or 0.05)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
accelerator.unwrap_model(network).set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
ref_noise_pred, _noisy_latents, ref_target, ref_timesteps, _weighting = self.get_noise_pred_and_target(
if args.ddo_beta is not None or args.ddo_alpha is not None:
with torch.no_grad():
accelerator.unwrap_model(network).set_multiplier(0.0)
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
torch.rand_like(latents),
latents,
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=False,
)
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(1.0)
# Apply DDO loss
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
ref_loss= train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c)
loss, metrics_ddo = ddo_loss(
loss,
ref_loss,
args.ddo_alpha or 4.0,
args.ddo_beta or 0.05,
weighting
)
metrics = {**metrics, **metrics_ddo}
elif args.beta_dpo is not None:
with torch.no_grad():
accelerator.unwrap_model(network).set_multiplier(0.0)
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
latents,
batch,
text_encoder_conds,
unet,
@@ -485,60 +521,15 @@ class NetworkTrainer:
train_unet,
is_train=is_train,
)
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(1.0)
# Combine real and fake batches
combined_latents = torch.cat([noise_pred, ref_noise_pred], dim=0)
combined_t = torch.cat([timesteps, ref_timesteps], dim=0)
# Compute target vector field (straight path for rectified flow)
v_target = compute_target_velocity(combined_latents, combined_t)
v_theta = noise_pred
v_theta_ref = ref_noise_pred
loss = ddo_loss(v_theta, v_theta_ref, v_target, combined_t)
else:
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
if args.beta_dpo is not None:
def call_unet():
accelerator.unwrap_model(network).set_multiplier(0.0)
with torch.no_grad(), accelerator.autocast():
ref_noise_pred, _noisy_latents, ref_target, ref_timesteps, _weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
torch.rand_like(latents),
batch,
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
is_train=is_train,
)
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(1.0)
return ref_noise_pred
def apply_loss(ref_noise_pred):
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
ref_loss = train_util.conditional_loss(
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
ref_loss = apply_masked_loss(ref_loss, batch)
return ref_loss
loss, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, args.beta_dpo)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
ref_loss = train_util.conditional_loss(
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss, metrics = diffusion_dpo_loss(loss, ref_loss, args.beta_dpo)
elif args.mapo_weight is not None:
loss, metrics = mapo_loss(loss, args.mapo_weight, noise_scheduler.config.num_train_timesteps)
else: