mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Rework DDO loss
This commit is contained in:
@@ -347,16 +347,23 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
weight_dtype: torch.dtype,
|
||||
train_unet: bool,
|
||||
is_train=True,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
timesteps: torch.FloatTensor | None=None,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
|
||||
noisy_model_input, rand_timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = rand_timesteps
|
||||
else:
|
||||
# Convert timesteps into sigmas
|
||||
sigmas: torch.FloatTensor = timesteps - noise_scheduler.config.num_train_timesteps
|
||||
|
||||
# pack latents and get img_ids
|
||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
|
||||
@@ -568,85 +568,36 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
|
||||
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
|
||||
|
||||
metrics = {
|
||||
"total_loss": loss.detach().mean().item(),
|
||||
"ratio_loss": -ratio_losses.detach().mean().item(),
|
||||
"model_losses_w": loss_w.detach().mean().item(),
|
||||
"model_losses_l": loss_l.detach().mean().item(),
|
||||
"win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
|
||||
"lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
|
||||
"loss/diffusion_dpo_total": loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(),
|
||||
"loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(),
|
||||
"loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(),
|
||||
"loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
|
||||
"loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
def ddo_loss(loss: Tensor, ref_loss: Tensor, ddo_alpha: float = 4.0, ddo_beta: float = 0.05, weighting: Tensor | None = None):
|
||||
"""
|
||||
Calculate DDO loss for flow matching diffusion models.
|
||||
|
||||
This implementation follows the paper's approach:
|
||||
1. Use prediction errors as proxy for log likelihood ratio
|
||||
2. Apply sigmoid to create a discriminator from this ratio
|
||||
3. Optimize using the standard GAN discriminator loss
|
||||
|
||||
Args:
|
||||
loss: loss B, N
|
||||
ref_loss: ref loss B, N
|
||||
ddo_alpha: Weight for the fake sample term
|
||||
ddo_beta: Scaling factor for the likelihood ratio
|
||||
weighting: Optional time-dependent weighting
|
||||
|
||||
Returns:
|
||||
The DDO loss value
|
||||
"""
|
||||
# Calculate per-sample MSE between predictions and target
|
||||
# Flatten spatial and channel dimensions, keeping batch dimension
|
||||
# target_error = ((noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
|
||||
# ref_error = ((ref_noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
|
||||
|
||||
# Apply weighting if provided (e.g., for time-dependent importance)
|
||||
if weighting is not None:
|
||||
if isinstance(weighting, tuple):
|
||||
# Use first element if it's a tuple
|
||||
weighting = weighting[0]
|
||||
if weighting.ndim > 1:
|
||||
# Ensure weighting is the right shape
|
||||
weighting = weighting.view(-1)
|
||||
loss = loss * weighting
|
||||
ref_loss = ref_loss * weighting
|
||||
|
||||
# Calculate the log likelihood ratio
|
||||
# For flow matching, lower error = higher likelihood
|
||||
# So the log ratio is proportional to negative of error difference
|
||||
def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05):
|
||||
ref_loss = ref_loss.detach() # Ensure no gradients to reference
|
||||
log_ratio = ddo_beta * (ref_loss - loss)
|
||||
|
||||
# Divide batch into real and fake samples (mid-point split)
|
||||
# In this implementation, the entire batch is treated as real samples
|
||||
# and each sample is compared against its own reference prediction
|
||||
# This approach works because the reference model (with LoRA disabled)
|
||||
# produces predictions that serve as the "fake" distribution
|
||||
|
||||
# Loss for real samples: maximize log σ(ratio)
|
||||
real_loss_terms = -torch.nn.functional.logsigmoid(log_ratio)
|
||||
real_loss = real_loss_terms.mean()
|
||||
|
||||
# Loss for fake samples: maximize log(1-σ(ratio))
|
||||
# Since we're using the same batch for both real and fake,
|
||||
# we interpret this as maximizing log(1-σ(ratio)) for the samples when viewed from reference
|
||||
fake_loss_terms = -torch.nn.functional.logsigmoid(-log_ratio)
|
||||
fake_loss = ddo_alpha * fake_loss_terms.mean()
|
||||
|
||||
real_loss = -torch.log(torch.sigmoid(log_ratio) + 1e-6).mean()
|
||||
fake_loss = -ddo_alpha * torch.log(1 - torch.sigmoid(log_ratio) + 1e-6).mean()
|
||||
total_loss = real_loss + fake_loss
|
||||
|
||||
metrics = {
|
||||
"loss/ddo_real": real_loss.detach().item(),
|
||||
"loss/ddo_fake": fake_loss.detach().item(),
|
||||
"loss/ddo_total": total_loss.detach().item(),
|
||||
"ddo_log_ratio_mean": log_ratio.detach().mean().item(),
|
||||
"loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(),
|
||||
}
|
||||
|
||||
# logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}")
|
||||
# logger.debug(f"difference: {(ref_loss - loss).mean().item()}")
|
||||
# logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}")
|
||||
# logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}")
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
# Perlin Noise
|
||||
|
||||
@@ -270,10 +270,14 @@ class NetworkTrainer:
|
||||
weight_dtype: torch.dtype,
|
||||
train_unet: bool,
|
||||
is_train=True,
|
||||
timesteps=None
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
noise, noisy_latents, rand_timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = rand_timesteps
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
@@ -475,34 +479,34 @@ class NetworkTrainer:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
if args.ddo_beta is not None or args.ddo_alpha is not None:
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=False,
|
||||
)
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=False,
|
||||
timesteps=timesteps,
|
||||
)
|
||||
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(1.0)
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(1.0)
|
||||
|
||||
# Apply DDO loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, ref_timesteps, noise_scheduler)
|
||||
ref_loss= train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None and ref_weighting is not None:
|
||||
ddo_weighting = weighting * ref_weighting
|
||||
loss, metrics_ddo = ddo_loss(
|
||||
loss.mean(dim=(1, 2, 3)),
|
||||
ref_loss.mean(dim=(1, 2, 3)),
|
||||
loss.mean(dim=(1, 2, 3)) * (weighting if weighting is not None else 1),
|
||||
ref_loss.mean(dim=(1, 2, 3)) * (ref_weighting if ref_weighting is not None else 1),
|
||||
args.ddo_alpha or 4.0,
|
||||
args.ddo_beta or 0.05,
|
||||
weighting
|
||||
)
|
||||
metrics = {**metrics, **metrics_ddo}
|
||||
elif args.beta_dpo is not None:
|
||||
|
||||
Reference in New Issue
Block a user