mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Add DDO loss
This commit is contained in:
@@ -384,6 +384,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||
model_pred = unet(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
|
||||
@@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
with torch.no_grad():
|
||||
stream = torch.cuda.Stream()
|
||||
with torch.cuda.stream(stream):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
|
||||
stream.synchronize()
|
||||
stream.synchronize()
|
||||
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
# cpu to cuda
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
|
||||
module_to_cuda.weight.data = cuda_data_view
|
||||
|
||||
stream.synchronize()
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
@@ -505,41 +505,23 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
loss = loss * mask_image
|
||||
return loss
|
||||
|
||||
def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor], apply_loss: Callable[[torch.Tensor], torch.Tensor], beta_dpo: float):
|
||||
def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float):
|
||||
"""
|
||||
DPO loss
|
||||
Diffusion DPO loss
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2, C, H, W
|
||||
call_unet: function to call unet
|
||||
apply_loss: function to apply loss
|
||||
loss: pairs of w, l losses B//2
|
||||
ref_loss: ref pairs of w, l losses B//2
|
||||
beta_dpo: beta_dpo weight
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- loss: mean loss of C, H, W
|
||||
- metrics:
|
||||
- total_loss: mean loss of C, H, W
|
||||
- raw_model_loss: mean loss of C, H, W
|
||||
- ref_loss: mean loss of C, H, W
|
||||
- implicit_acc: accumulated implicit of C, H, W
|
||||
|
||||
"""
|
||||
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
model_loss_w, model_loss_l = model_loss.chunk(2)
|
||||
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
|
||||
model_diff = model_loss_w - model_loss_l
|
||||
|
||||
# ref loss
|
||||
with torch.no_grad():
|
||||
ref_noise_pred = call_unet()
|
||||
ref_loss = apply_loss(ref_noise_pred)
|
||||
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean()
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1))
|
||||
model_diff = loss_w - loss_l
|
||||
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean(dim=1)
|
||||
|
||||
scale_term = -0.5 * beta_dpo
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
@@ -549,10 +531,10 @@ def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor],
|
||||
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
|
||||
|
||||
metrics = {
|
||||
"total_loss": model_loss.detach().mean().item(),
|
||||
"raw_model_loss": raw_model_loss.detach().mean().item(),
|
||||
"ref_loss": raw_ref_loss.detach().item(),
|
||||
"implicit_acc": implicit_acc.detach().item(),
|
||||
"loss/diffusion_dpo_total_loss": loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(),
|
||||
"loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(),
|
||||
"loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
@@ -563,28 +545,15 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2, C, H, W
|
||||
mapo_loss: mapo weight
|
||||
mapo_weight: mapo weight
|
||||
num_train_timesteps: number of timesteps
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- loss: mean loss of C, H, W
|
||||
- metrics:
|
||||
- total_loss: mean loss of C, H, W
|
||||
- ratio_loss: mean ratio loss of C, H, W
|
||||
- model_losses_w: mean loss of w losses of C, H, W
|
||||
- model_losses_l: mean loss of l losses of C, H, W
|
||||
- win_score : mean win score of C, H, W
|
||||
- lose_score : mean lose score of C, H, W
|
||||
|
||||
"""
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
|
||||
snr = 0.5
|
||||
model_losses_w, model_losses_l = model_loss.chunk(2)
|
||||
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
|
||||
snr * model_losses_l
|
||||
) / (torch.exp(snr * model_losses_l) - 1)
|
||||
loss_w, loss_l = loss.chunk(2)
|
||||
log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - (
|
||||
snr * loss_l
|
||||
) / (torch.exp(snr * loss_l) - 1)
|
||||
|
||||
# Ratio loss.
|
||||
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
|
||||
@@ -592,141 +561,91 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000)
|
||||
ratio_losses = mapo_weight * ratio
|
||||
|
||||
# Full MaPO loss
|
||||
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
|
||||
loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1)
|
||||
|
||||
metrics = {
|
||||
"total_loss": loss.detach().mean().item(),
|
||||
"ratio_loss": -ratio_losses.detach().mean().item(),
|
||||
"model_losses_w": model_losses_w.detach().mean().item(),
|
||||
"model_losses_l": model_losses_l.detach().mean().item(),
|
||||
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)).detach().mean().item(),
|
||||
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)).detach().mean().item(),
|
||||
"model_losses_w": loss_w.detach().mean().item(),
|
||||
"model_losses_l": loss_l.detach().mean().item(),
|
||||
"win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(),
|
||||
"lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
class FlowMatchingDDOLoss(nn.Module):
|
||||
def __init__(self, alpha=4.0, beta=0.05):
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
|
||||
def forward(
|
||||
self, v_theta: Tensor, v_theta_ref: Tensor, v_target: Tensor, time=None
|
||||
):
|
||||
"""
|
||||
Compute DDO loss for flow matching models
|
||||
|
||||
Args:
|
||||
v_theta: Vector field predicted by target model
|
||||
v_theta_ref: Vector field predicted by reference model
|
||||
v_target: Target vector field (e.g., straight-line for rectified flow)
|
||||
time: Time parameter t
|
||||
|
||||
Returns:
|
||||
DDO loss value
|
||||
"""
|
||||
# For flow matching, error is based on vector field difference
|
||||
error_theta = torch.sum((v_theta - v_target) ** 2, dim=[1, 2, 3])
|
||||
error_theta_ref = torch.sum((v_theta_ref - v_target) ** 2, dim=[1, 2, 3])
|
||||
|
||||
# Likelihood ratio approximation
|
||||
delta = error_theta_ref - error_theta
|
||||
scaled_delta = self.beta * delta
|
||||
|
||||
# Split batch into real and fake parts
|
||||
batch_size = v_theta.shape[0]
|
||||
half_batch = batch_size // 2
|
||||
|
||||
real_delta = scaled_delta[:half_batch]
|
||||
fake_delta = scaled_delta[half_batch:]
|
||||
|
||||
real_loss = -F.logsigmoid(real_delta).mean()
|
||||
fake_loss = -F.logsigmoid(-fake_delta).mean()
|
||||
|
||||
loss = real_loss + self.alpha * fake_loss
|
||||
|
||||
return loss
|
||||
|
||||
def compute_target_velocity(x_t: Tensor, t: Tensor):
|
||||
def ddo_loss(
|
||||
loss: Tensor,
|
||||
ref_loss: Tensor,
|
||||
ddo_alpha: float=4.0,
|
||||
ddo_beta: float=0.05,
|
||||
weighting: Tensor | None=None
|
||||
):
|
||||
"""
|
||||
Compute the target velocity vector field for flow matching.
|
||||
|
||||
For rectified flow, the target velocity is the straight-line path derivative.
|
||||
|
||||
Calculate DDO loss for flow matching diffusion models.
|
||||
|
||||
This implementation follows the paper's approach:
|
||||
1. Use prediction errors as proxy for log likelihood ratio
|
||||
2. Apply sigmoid to create a discriminator from this ratio
|
||||
3. Optimize using the standard GAN discriminator loss
|
||||
|
||||
Args:
|
||||
x_t: Points along the path at time t (batch_size, channels, height, width)
|
||||
t: Time values in [0,1] (batch_size,)
|
||||
|
||||
loss: loss B, N
|
||||
ref_loss: ref loss B, N
|
||||
ddo_alpha: Weight for the fake sample term
|
||||
ddo_beta: Scaling factor for the likelihood ratio
|
||||
weighting: Optional time-dependent weighting
|
||||
|
||||
Returns:
|
||||
Target velocity vectors v(x_t, t) for flow matching
|
||||
The DDO loss value
|
||||
"""
|
||||
batch_size = x_t.shape[0]
|
||||
|
||||
# Get corresponding data and noise endpoints
|
||||
with torch.no_grad():
|
||||
# For each interpolated point, we need the endpoints of its path
|
||||
# In practice, these might come from a cache or be passed as arguments
|
||||
x1 = get_data_endpoints(x_t, t) # Real data endpoint (t=0)
|
||||
x0 = get_noise_endpoints(x_t, t) # Noise endpoint (t=1)
|
||||
|
||||
# Reshape t for broadcasting
|
||||
t = t.view(batch_size, 1, 1, 1)
|
||||
|
||||
# For standard rectified flow, the target velocity is constant along the path:
|
||||
# v(x_t, t) = x1 - x0
|
||||
v_target = x1 - x0
|
||||
|
||||
# For time-dependent velocity fields (non-rectified), we would scale by time:
|
||||
# v_target = v_target * g(t) # where g(t) is a time-dependent scaling function
|
||||
|
||||
return v_target
|
||||
|
||||
|
||||
def get_data_endpoints(x_t: Tensor, t: Tensor):
|
||||
"""
|
||||
Get the data endpoints (t=0) for the given points on the path.
|
||||
|
||||
For training with real data, this would typically use the encoded real data.
|
||||
For inference or when using generated endpoints, we'd solve for them.
|
||||
|
||||
Args:
|
||||
x_t: Points on the path at time t
|
||||
t: Time values
|
||||
|
||||
Returns:
|
||||
The data endpoints (x at t=0)
|
||||
"""
|
||||
# Solve for x1 using the straight-line path: x_t = (1-t)*x1 + t*x0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
x0 = torch.randn_like(x_t) # Noise endpoint
|
||||
|
||||
# Solve for x1: x1 = (x_t - t*x0) / (1-t)
|
||||
# Add small epsilon to prevent division by zero
|
||||
epsilon = 1e-8
|
||||
x1 = (x_t - t * x0) / (torch.clamp(1 - t, min=epsilon))
|
||||
|
||||
return x1
|
||||
|
||||
|
||||
def get_noise_endpoints(x_t: Tensor, t: Tensor):
|
||||
"""
|
||||
Get the noise endpoints (t=1) for the given points on the path.
|
||||
|
||||
For standard rectified flow, this is typically Gaussian noise.
|
||||
|
||||
Args:
|
||||
x_t: Points on the path at time t
|
||||
t: Time values
|
||||
|
||||
Returns:
|
||||
The noise endpoints (x at t=1)
|
||||
"""
|
||||
|
||||
# Generate noise samples matching the shape of x_t
|
||||
x0 = torch.randn_like(x_t)
|
||||
|
||||
return x0
|
||||
# Calculate per-sample MSE between predictions and target
|
||||
# Flatten spatial and channel dimensions, keeping batch dimension
|
||||
# target_error = ((noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
|
||||
# ref_error = ((ref_noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1)
|
||||
|
||||
# Apply weighting if provided (e.g., for time-dependent importance)
|
||||
if weighting is not None:
|
||||
if isinstance(weighting, tuple):
|
||||
# Use first element if it's a tuple
|
||||
weighting = weighting[0]
|
||||
if weighting.ndim > 1:
|
||||
# Ensure weighting is the right shape
|
||||
weighting = weighting.view(-1)
|
||||
loss = loss * weighting
|
||||
ref_loss = ref_loss * weighting
|
||||
|
||||
# Calculate the log likelihood ratio
|
||||
# For flow matching, lower error = higher likelihood
|
||||
# So the log ratio is proportional to negative of error difference
|
||||
log_ratio = ddo_beta * (ref_loss - loss)
|
||||
|
||||
# Divide batch into real and fake samples (mid-point split)
|
||||
# In this implementation, the entire batch is treated as real samples
|
||||
# and each sample is compared against its own reference prediction
|
||||
# This approach works because the reference model (with LoRA disabled)
|
||||
# produces predictions that serve as the "fake" distribution
|
||||
|
||||
# Loss for real samples: maximize log σ(ratio)
|
||||
real_loss_terms = -torch.nn.functional.logsigmoid(log_ratio)
|
||||
real_loss = real_loss_terms.mean()
|
||||
|
||||
# Loss for fake samples: maximize log(1-σ(ratio))
|
||||
# Since we're using the same batch for both real and fake,
|
||||
# we interpret this as maximizing log(1-σ(ratio)) for the samples when viewed from reference
|
||||
fake_loss_terms = -torch.nn.functional.logsigmoid(-log_ratio)
|
||||
fake_loss = ddo_alpha * fake_loss_terms.mean()
|
||||
|
||||
total_loss = real_loss + fake_loss
|
||||
|
||||
metrics = {
|
||||
"loss/ddo_real": real_loss.detach().item(),
|
||||
"loss/ddo_fake": fake_loss.detach().item(),
|
||||
"loss/ddo_total": total_loss.detach().item(),
|
||||
"ddo_log_ratio_mean": log_ratio.detach().mean().item(),
|
||||
}
|
||||
|
||||
return total_loss, metrics
|
||||
|
||||
|
||||
"""
|
||||
|
||||
119
train_network.py
119
train_network.py
@@ -37,6 +37,7 @@ import library.huggingface_util as huggingface_util
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import (
|
||||
apply_snr_weight,
|
||||
ddo_loss,
|
||||
get_weighted_text_embeddings,
|
||||
prepare_scheduler_for_custom_training,
|
||||
scale_v_prediction_loss_like_noise_prediction,
|
||||
@@ -45,8 +46,7 @@ from library.custom_train_functions import (
|
||||
apply_masked_loss,
|
||||
diffusion_dpo_loss,
|
||||
mapo_loss,
|
||||
FlowMatchingDDOLoss,
|
||||
compute_target_velocity,
|
||||
calculate_ddo_loss_for_dit_flow_matching,
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
@@ -270,7 +270,7 @@ class NetworkTrainer:
|
||||
weight_dtype: torch.dtype,
|
||||
train_unet: bool,
|
||||
is_train=True,
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
@@ -324,8 +324,8 @@ class NetworkTrainer:
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
return noise_pred, noisy_latents, target, timesteps, None
|
||||
sigmas = timesteps / noise_scheduler.config.num_train_timesteps
|
||||
return noise_pred, noisy_latents, target, sigmas, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
@@ -452,7 +452,8 @@ class NetworkTrainer:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, noisy_latents, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
|
||||
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
@@ -466,17 +467,52 @@ class NetworkTrainer:
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
if args.ddo_beta is not None or args.ddo_alpha is not None:
|
||||
# Compute DDO loss
|
||||
ddo_loss = FlowMatchingDDOLoss(alpha=args.ddo_beta or 4.0, beta=args.ddo_alpha or 0.05)
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
ref_noise_pred, _noisy_latents, ref_target, ref_timesteps, _weighting = self.get_noise_pred_and_target(
|
||||
if args.ddo_beta is not None or args.ddo_alpha is not None:
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
torch.rand_like(latents),
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=False,
|
||||
)
|
||||
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(1.0)
|
||||
|
||||
# Apply DDO loss
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
ref_loss= train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c)
|
||||
loss, metrics_ddo = ddo_loss(
|
||||
loss,
|
||||
ref_loss,
|
||||
args.ddo_alpha or 4.0,
|
||||
args.ddo_beta or 0.05,
|
||||
weighting
|
||||
)
|
||||
metrics = {**metrics, **metrics_ddo}
|
||||
elif args.beta_dpo is not None:
|
||||
with torch.no_grad():
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
@@ -485,60 +521,15 @@ class NetworkTrainer:
|
||||
train_unet,
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(1.0)
|
||||
|
||||
# Combine real and fake batches
|
||||
combined_latents = torch.cat([noise_pred, ref_noise_pred], dim=0)
|
||||
combined_t = torch.cat([timesteps, ref_timesteps], dim=0)
|
||||
|
||||
# Compute target vector field (straight path for rectified flow)
|
||||
v_target = compute_target_velocity(combined_latents, combined_t)
|
||||
v_theta = noise_pred
|
||||
v_theta_ref = ref_noise_pred
|
||||
|
||||
loss = ddo_loss(v_theta, v_theta_ref, v_target, combined_t)
|
||||
else:
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
if args.beta_dpo is not None:
|
||||
def call_unet():
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
ref_noise_pred, _noisy_latents, ref_target, ref_timesteps, _weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
torch.rand_like(latents),
|
||||
batch,
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(1.0)
|
||||
return ref_noise_pred
|
||||
def apply_loss(ref_noise_pred):
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
ref_loss = train_util.conditional_loss(
|
||||
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
ref_loss = apply_masked_loss(ref_loss, batch)
|
||||
return ref_loss
|
||||
|
||||
loss, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, args.beta_dpo)
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
ref_loss = train_util.conditional_loss(
|
||||
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
|
||||
loss, metrics = diffusion_dpo_loss(loss, ref_loss, args.beta_dpo)
|
||||
elif args.mapo_weight is not None:
|
||||
loss, metrics = mapo_loss(loss, args.mapo_weight, noise_scheduler.config.num_train_timesteps)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user