mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
Update PO cached latents, move out functions, update calls
This commit is contained in:
@@ -4,7 +4,7 @@ import argparse
|
||||
import random
|
||||
import re
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Callable
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -502,6 +502,106 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
loss = loss * mask_image
|
||||
return loss
|
||||
|
||||
def diffusion_dpo_loss(
    loss: torch.Tensor,
    call_unet: Callable[[], torch.Tensor],
    apply_loss: Callable[[torch.Tensor], torch.Tensor],
    beta_dpo: float,
) -> tuple[torch.Tensor, dict[str, float]]:
    """
    Diffusion-DPO loss.

    The batch is split in half along dim 0 with ``chunk(2)``: the first half is
    the "w" group and the second half the "l" group (presumably the preferred
    and the rejected samples of each pair -- confirm against the data loader).

    Args:
        loss: element-wise loss of the model being trained. All dimensions
            after the first are averaged to get one value per sample
            (assumes a layout like B, C, H, W with the w samples first).
        call_unet: zero-argument callable returning the reference prediction.
            It is invoked under ``torch.no_grad()``.
        apply_loss: callable mapping the reference prediction to an
            element-wise loss, reduced the same way as ``loss``.
        beta_dpo: DPO weight. The logit fed to ``logsigmoid`` is
            ``-0.5 * beta_dpo * (model_diff - ref_diff)``.

    Returns:
        tuple:
            - loss: ``-logsigmoid(logit)``, one value per (w, l) pair; it is
              NOT reduced over the pairs.
            - metrics: dict of Python floats
                - total_loss: mean of the returned DPO loss
                - raw_model_loss: average of the mean w and mean l model losses
                - ref_loss: mean per-sample loss of the reference prediction
                - implicit_acc: fraction of pairs with a positive logit,
                  counting exact ties as one half
    """
    # Collapse every non-batch dimension so each sample has a single loss value.
    model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
    model_loss_w, model_loss_l = model_loss.chunk(2)
    raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
    model_diff = model_loss_w - model_loss_l

    # Reference loss: computed without tracking gradients.
    with torch.no_grad():
        ref_noise_pred = call_unet()
        ref_loss = apply_loss(ref_noise_pred)
        ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
        ref_losses_w, ref_losses_l = ref_loss.chunk(2)
        ref_diff = ref_losses_w - ref_losses_l
        raw_ref_loss = ref_loss.mean()

    scale_term = -0.5 * beta_dpo
    inside_term = scale_term * (model_diff - ref_diff)
    loss = -1 * torch.nn.functional.logsigmoid(inside_term)

    # A positive logit counts as a full hit, an exact tie as half a hit.
    num_pairs = inside_term.size(0)
    implicit_acc = (inside_term > 0).sum().float() / num_pairs
    implicit_acc += 0.5 * (inside_term == 0).sum().float() / num_pairs

    metrics = {
        # Report the DPO objective itself. The plain model loss is already
        # available as "raw_model_loss", so deriving this entry from
        # `model_loss` would only duplicate that value.
        "total_loss": loss.detach().mean().item(),
        "raw_model_loss": raw_model_loss.detach().mean().item(),
        "ref_loss": raw_ref_loss.detach().item(),
        "implicit_acc": implicit_acc.detach().item(),
    }

    return loss, metrics
|
||||
|
||||
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
MaPO loss
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2, C, H, W
|
||||
mapo_loss: mapo weight
|
||||
num_train_timesteps: number of timesteps
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- loss: mean loss of C, H, W
|
||||
- metrics:
|
||||
- total_loss: mean loss of C, H, W
|
||||
- ratio_loss: mean ratio loss of C, H, W
|
||||
- model_losses_w: mean loss of w losses of C, H, W
|
||||
- model_losses_l: mean loss of l losses of C, H, W
|
||||
- win_score : mean win score of C, H, W
|
||||
- lose_score : mean lose score of C, H, W
|
||||
|
||||
"""
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
|
||||
snr = 0.5
|
||||
model_losses_w, model_losses_l = model_loss.chunk(2)
|
||||
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
|
||||
snr * model_losses_l
|
||||
) / (torch.exp(snr * model_losses_l) - 1)
|
||||
|
||||
# Ratio loss.
|
||||
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
|
||||
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
|
||||
ratio_losses = mapo_weight * ratio
|
||||
|
||||
# Full MaPO loss
|
||||
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
|
||||
|
||||
metrics = {
|
||||
"total_loss": loss.detach().mean().item(),
|
||||
"ratio_loss": -ratio_losses.detach().mean().item(),
|
||||
"model_losses_w": model_losses_w.detach().mean().item(),
|
||||
"model_losses_l": model_losses_l.detach().mean().item(),
|
||||
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)).detach().mean().item(),
|
||||
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)).detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
|
||||
Reference in New Issue
Block a user