Preference optimization with MaPO and Diffusion-DPO

This commit is contained in:
rockerBOO
2024-06-19 16:22:11 -04:00
parent e5bab69e3a
commit 44fa71c78f
3 changed files with 365 additions and 102 deletions

View File

@@ -906,10 +906,93 @@ class NetworkTrainer:
)
# Post-process the element-wise loss tensor `loss` into the value used for
# weighting below. Three paths are visible: Diffusion-DPO (args.beta_dpo set),
# MaPO (args.mapo_weight set), or a plain per-sample mean.
# Pass the loss through apply_masked_loss when masked loss is requested or the
# batch carries alpha masks.
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
# NOTE(review): this statement also appears in the final `else` branch below;
# in this diff view it is presumably the pre-change line that the commit
# removes -- confirm against the actual file.
loss = loss.mean([1, 2, 3])
# --- Diffusion-DPO branch ---
if args.beta_dpo is not None:
# Reduce every non-batch dimension so there is one loss value per sample.
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
# chunk(2) splits the batch into two halves; the names suggest the first half
# holds the preferred ("w") samples and the second the rejected ("l") samples
# -- assumes the batch is laid out that way, TODO confirm against the dataset.
model_loss_w, model_loss_l = model_loss.chunk(2)
# Average of the two halves' mean losses; only used for logging below.
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
model_diff = model_loss_w - model_loss_l
# ref loss
# The reference prediction is computed without gradients and with the network
# multiplier set to 0.0, then the multiplier is restored to `multipliers`.
with torch.no_grad():
# disable network for reference
accelerator.unwrap_model(network).set_multiplier(0.0)
with accelerator.autocast():
# NOTE(review): requires_grad_ mutates noisy_latents in place even though this
# call sits inside torch.no_grad() -- confirm this is intended.
ref_noise_pred = self.call_unet(
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
# Same loss function, masking and per-sample reduction as the model loss,
# applied to the reference prediction.
ref_loss = train_util.conditional_loss(
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
ref_loss = apply_masked_loss(ref_loss, batch)
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean()
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(multipliers)
# loss = -logsigmoid(-0.5 * beta_dpo * (model_diff - ref_diff)), one value per pair.
scale_term = -0.5 * args.beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
# Fraction of pairs with a positive inside_term; exact ties count as one half.
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
# NOTE(review): "total_loss" logs the mean of model_loss (the per-sample
# denoising loss), not the DPO `loss` computed above -- confirm this is intended.
accelerator.log({
"total_loss": model_loss.detach().mean().item(),
"raw_model_loss": raw_model_loss.detach().mean().item(),
"ref_loss": raw_ref_loss.detach().item(),
"implicit_acc": implicit_acc.detach().item(),
}, step=global_step)
# --- MaPO branch ---
elif args.mapo_weight is not None:
# Reduce every non-batch dimension so there is one loss value per sample.
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
# Fixed constant used to scale the per-sample losses in the score below.
snr = 0.5
model_losses_w, model_losses_l = model_loss.chunk(2)
# Score x / (exp(x) - 1) with x = snr * loss, evaluated for each half; log_odds
# is the first-half score minus the second-half score.
# NOTE(review): the score is 0 / 0 when a per-sample loss is exactly zero.
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
snr * model_losses_l
) / (torch.exp(snr * model_losses_l) - 1)
# Ratio loss.
# By multiplying the inner term by T (num_train_timesteps), we try to maximize the margin throughout the overall denoising process.
ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps)
ratio_losses = args.mapo_weight * ratio
# Full MaPO loss
# NOTE(review): model_losses_w and ratio_losses are already 1-D here, so the
# dim list passed to mean() is empty -- confirm how the installed torch version
# treats an empty dim list (the resulting shape of `loss` depends on it).
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
# "win_score" / "lose_score" recompute the two score terms of log_odds for logging.
accelerator.log({
"total_loss": loss.detach().mean().item(),
"ratio_loss": -ratio_losses.mean().detach().item(),
"model_losses_w": model_losses_w.mean().detach().item(),
"model_losses_l": model_losses_l.mean().detach().item(),
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1))
.mean()
.detach()
.item(),
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1))
.mean()
.detach()
.item()
}, step=global_step)
# --- Default branch: plain per-sample mean over dims 1-3 ---
else:
loss = loss.mean([1, 2, 3])
# NOTE(review): in the DPO branch `loss` has one entry per pair (half the
# batch) while loss_weights is read for the whole batch -- confirm the shapes
# are compatible in the multiplication below.
loss_weights = batch["loss_weights"] # weight for each sample
loss = loss * loss_weights
# Replace timesteps with a list repeating its first element, one entry per
# element of loss.
# NOTE(review): this discards every timestep except the first; presumably all
# samples are expected to share one timestep here -- confirm.
timesteps = [timesteps[0]] * loss.shape[0]
# Apply apply_snr_weight only when args.min_snr_gamma is truthy.
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
@@ -983,7 +1066,8 @@ class NetworkTrainer:
# When a logging directory is configured, log the moving-average loss under
# the "loss/epoch" key.
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
# NOTE(review): in this diff view both the epoch-indexed call and the
# global_step-indexed call appear; presumably the first is the pre-change line
# that the commit replaces with the second -- confirm against the actual file,
# otherwise the same logs are emitted twice with different step values.
accelerator.log(logs, step=epoch + 1)
# accelerator.log(logs, step=epoch + 1)
accelerator.log(logs, step=global_step)
# Synchronize all processes before continuing.
accelerator.wait_for_everyone()