mirror of https://github.com/kohya-ss/sd-scripts.git
Preference optimization with MaPO and Diffusion-DPO
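This commit adds two optional preference-optimization objectives to NetworkTrainer's loss computation: Diffusion-DPO (enabled via args.beta_dpo) and MaPO (enabled via args.mapo_weight). Both branches assume each batch is the concatenation of a preferred ("win") half and a dispreferred ("lose") half, which is why the per-sample losses are split with .chunk(2). Reading the DPO branch off the diff, and writing l_w / l_l for the per-sample denoising losses on the two halves, it implements (with constants folded into beta; this is a reading of the code, not a formula stated by the commit):

\[
\mathcal{L}_{\mathrm{DPO}}
  = -\log\sigma\!\Big(\!-\tfrac{\beta}{2}\big[(\ell^{w}_{\theta}-\ell^{l}_{\theta})-(\ell^{w}_{\mathrm{ref}}-\ell^{l}_{\mathrm{ref}})\big]\Big)
\]

The reference losses come from re-running the U-Net with the network multiplier set to 0.0, i.e. the base model without the trained network serves as the reference model. The fraction of pairs with a positive inside term is logged as implicit_acc.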
@@ -906,10 +906,93 @@ class NetworkTrainer:
                     )
                     if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
                         loss = apply_masked_loss(loss, batch)
-                    loss = loss.mean([1, 2, 3])
 
+                    if args.beta_dpo is not None:
+                        model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
+                        model_loss_w, model_loss_l = model_loss.chunk(2)
+                        raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
+                        model_diff = model_loss_w - model_loss_l
+
+                        # ref loss
+                        with torch.no_grad():
+                            # disable network for reference
+                            accelerator.unwrap_model(network).set_multiplier(0.0)
+
+                            with accelerator.autocast():
+                                ref_noise_pred = self.call_unet(
+                                    args,
+                                    accelerator,
+                                    unet,
+                                    noisy_latents.requires_grad_(train_unet),
+                                    timesteps,
+                                    text_encoder_conds,
+                                    batch,
+                                    weight_dtype,
+                                )
+                            ref_loss = train_util.conditional_loss(
+                                ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                            )
+                            if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
+                                ref_loss = apply_masked_loss(ref_loss, batch)
+
+                            ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
+                            ref_losses_w, ref_losses_l = ref_loss.chunk(2)
+                            ref_diff = ref_losses_w - ref_losses_l
+                            raw_ref_loss = ref_loss.mean()
+
+                            # reset network multipliers
+                            accelerator.unwrap_model(network).set_multiplier(multipliers)
+
+                        scale_term = -0.5 * args.beta_dpo
+                        inside_term = scale_term * (model_diff - ref_diff)
+                        loss = -1 * torch.nn.functional.logsigmoid(inside_term)
+
+                        implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
+                        implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
+
+                        accelerator.log({
+                            "total_loss": model_loss.detach().mean().item(),
+                            "raw_model_loss": raw_model_loss.detach().mean().item(),
+                            "ref_loss": raw_ref_loss.detach().item(),
+                            "implicit_acc": implicit_acc.detach().item(),
+                        }, step=global_step)
+                    elif args.mapo_weight is not None:
+                        model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
+
+                        snr = 0.5
+                        model_losses_w, model_losses_l = model_loss.chunk(2)
+                        log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
+                            snr * model_losses_l
+                        ) / (torch.exp(snr * model_losses_l) - 1)
+
+                        # Ratio loss.
+                        # By multiplying the inner term by T, we try to maximize the margin throughout the overall denoising process.
+                        ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps)
+                        ratio_losses = args.mapo_weight * ratio
+
+                        # Full MaPO loss
+                        loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
+
+                        accelerator.log({
+                            "total_loss": loss.detach().mean().item(),
+                            "ratio_loss": -ratio_losses.mean().detach().item(),
+                            "model_losses_w": model_losses_w.mean().detach().item(),
+                            "model_losses_l": model_losses_l.mean().detach().item(),
+                            "win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1))
+                            .mean()
+                            .detach()
+                            .item(),
+                            "lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1))
+                            .mean()
+                            .detach()
+                            .item(),
+                        }, step=global_step)
+                    else:
+                        loss = loss.mean([1, 2, 3])
+
                     loss_weights = batch["loss_weights"]  # weight for each sample
                     loss = loss * loss_weights
+                    timesteps = [timesteps[0]] * loss.shape[0]
 
                     if args.min_snr_gamma:
                         loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
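Unlike the DPO branch, the MaPO branch needs no reference pass. Reading it off the diff: with c the hard-coded snr = 0.5, T = noise_scheduler.config.num_train_timesteps, and lambda = args.mapo_weight, the per-sample loss is

\[
\mathcal{L}_{\mathrm{MaPO}}
  = \ell^{w}_{\theta} - \lambda\,\log\sigma\!\big(T\,[\,g(\ell^{w}_{\theta}) - g(\ell^{l}_{\theta})\,]\big),
\qquad g(x) = \frac{c\,x}{e^{c x} - 1}.
\]

g is positive and strictly decreasing for x > 0, so the margin term grows when the win sample is denoised better than the lose sample; multiplying by T (per the in-code comment) spreads the margin over the whole denoising process. The logged win_score and lose_score are g(l_w) and g(l_l).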
@@ -983,7 +1066,8 @@ class NetworkTrainer:
 
             if args.logging_dir is not None:
                 logs = {"loss/epoch": loss_recorder.moving_average}
-                accelerator.log(logs, step=epoch + 1)
+                # accelerator.log(logs, step=epoch + 1)
+                accelerator.log(logs, step=global_step)
 
             accelerator.wait_for_everyone()
 
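As a sanity check, the DPO inside-term and the logged implicit accuracy are easy to reproduce outside the trainer. A minimal sketch on dummy per-sample losses follows; beta_dpo = 5000.0 and all tensor values are illustrative only, not taken from this commit:

    # Toy reproduction of the Diffusion-DPO branch on dummy per-sample losses.
    import torch
    import torch.nn.functional as F

    beta_dpo = 5000.0  # illustrative value, not from this commit

    # First half of the batch = "win" samples, second half = "lose" samples.
    model_loss = torch.tensor([0.10, 0.12, 0.30, 0.35])  # network enabled
    ref_loss = torch.tensor([0.20, 0.22, 0.25, 0.28])    # network multiplier 0.0

    model_w, model_l = model_loss.chunk(2)
    ref_w, ref_l = ref_loss.chunk(2)

    inside_term = -0.5 * beta_dpo * ((model_w - model_l) - (ref_w - ref_l))
    loss = -F.logsigmoid(inside_term)  # near zero when the network favors "win" more than the reference does

    implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
    implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
    print(loss.mean().item(), implicit_acc.item())  # ~0.0, 1.0

Here the network fits the win images better than the lose images relative to the reference, so inside_term is positive, the loss is near zero, and implicit_acc is 1.0.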