mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
Update PO cached latents, move out functions, update calls
This commit is contained in:
@@ -336,24 +336,24 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
args: argparse.Namespace,
|
||||
accelerator: Accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
latents: torch.FloatTensor,
|
||||
batch: dict[str, torch.Tensor],
|
||||
text_encoder_conds,
|
||||
unet: flux_models.Flux,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
weight_dtype: torch.dtype,
|
||||
train_unet: bool,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# get noisy model input and timesteps
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
|
||||
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
|
||||
)
|
||||
|
||||
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
return model_pred, noisy_model_input, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
@@ -4,7 +4,7 @@ import argparse
|
||||
import random
|
||||
import re
|
||||
from torch.types import Number
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, Callable
|
||||
from .utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -502,6 +502,106 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
|
||||
loss = loss * mask_image
|
||||
return loss
|
||||
|
||||
def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor], apply_loss: Callable[[torch.Tensor], torch.Tensor], beta_dpo: float):
|
||||
"""
|
||||
DPO loss
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2, C, H, W
|
||||
call_unet: function to call unet
|
||||
apply_loss: function to apply loss
|
||||
beta_dpo: beta_dpo weight
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- loss: mean loss of C, H, W
|
||||
- metrics:
|
||||
- total_loss: mean loss of C, H, W
|
||||
- raw_model_loss: mean loss of C, H, W
|
||||
- ref_loss: mean loss of C, H, W
|
||||
- implicit_acc: accumulated implicit of C, H, W
|
||||
|
||||
"""
|
||||
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
model_loss_w, model_loss_l = model_loss.chunk(2)
|
||||
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
|
||||
model_diff = model_loss_w - model_loss_l
|
||||
|
||||
# ref loss
|
||||
with torch.no_grad():
|
||||
ref_noise_pred = call_unet()
|
||||
ref_loss = apply_loss(ref_noise_pred)
|
||||
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean()
|
||||
|
||||
|
||||
scale_term = -0.5 * beta_dpo
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
|
||||
|
||||
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
|
||||
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
|
||||
|
||||
metrics = {
|
||||
"total_loss": model_loss.detach().mean().item(),
|
||||
"raw_model_loss": raw_model_loss.detach().mean().item(),
|
||||
"ref_loss": raw_ref_loss.detach().item(),
|
||||
"implicit_acc": implicit_acc.detach().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
|
||||
"""
|
||||
MaPO loss
|
||||
|
||||
Args:
|
||||
loss: pairs of w, l losses B//2, C, H, W
|
||||
mapo_loss: mapo weight
|
||||
num_train_timesteps: number of timesteps
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- loss: mean loss of C, H, W
|
||||
- metrics:
|
||||
- total_loss: mean loss of C, H, W
|
||||
- ratio_loss: mean ratio loss of C, H, W
|
||||
- model_losses_w: mean loss of w losses of C, H, W
|
||||
- model_losses_l: mean loss of l losses of C, H, W
|
||||
- win_score : mean win score of C, H, W
|
||||
- lose_score : mean lose score of C, H, W
|
||||
|
||||
"""
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
|
||||
snr = 0.5
|
||||
model_losses_w, model_losses_l = model_loss.chunk(2)
|
||||
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
|
||||
snr * model_losses_l
|
||||
) / (torch.exp(snr * model_losses_l) - 1)
|
||||
|
||||
# Ratio loss.
|
||||
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
|
||||
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
|
||||
ratio_losses = mapo_weight * ratio
|
||||
|
||||
# Full MaPO loss
|
||||
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
|
||||
|
||||
metrics = {
|
||||
"total_loss": loss.detach().mean().item(),
|
||||
"ratio_loss": -ratio_losses.detach().mean().item(),
|
||||
"model_losses_w": model_losses_w.detach().mean().item(),
|
||||
"model_losses_l": model_losses_l.detach().mean().item(),
|
||||
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)).detach().mean().item(),
|
||||
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)).detach().mean().item(),
|
||||
}
|
||||
|
||||
return loss, metrics
|
||||
|
||||
|
||||
"""
|
||||
##########################################
|
||||
|
||||
@@ -359,7 +359,7 @@ def denoise(
|
||||
|
||||
|
||||
# region train
|
||||
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
||||
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) -> torch.FloatTensor:
|
||||
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
||||
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
||||
timesteps = timesteps.to(device)
|
||||
@@ -390,7 +390,7 @@ def compute_density_for_timestep_sampling(
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas) -> torch.Tensor:
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
@@ -407,35 +407,43 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
return weighting
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
def get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""
|
||||
Returns:
|
||||
tuple[
|
||||
noisy_model_input: noisy at sigma applied to latent
|
||||
timesteps: timesteps betweeen 1.0 and 1000.0
|
||||
sigmas: sigmas between 0.0 and 1.0
|
||||
]
|
||||
"""
|
||||
bsz, _, h, w = latents.shape
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
num_timesteps: int = noise_scheduler.config.num_train_timesteps
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random sigma-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
sigma = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
else:
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
sigma = torch.rand((bsz,), device=device)
|
||||
|
||||
timesteps = sigmas * num_timesteps
|
||||
timestep = sigma * num_timesteps
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
sigma = torch.randn(bsz, device=device)
|
||||
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigma = sigma.sigmoid()
|
||||
sigma = (sigma * shift) / (1 + (shift - 1) * sigma)
|
||||
timestep = sigma * num_timesteps
|
||||
elif args.timestep_sampling == "flux_shift":
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigma = torch.randn(bsz, device=device)
|
||||
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigma = sigma.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
sigma = time_shift(mu, 1.0, sigma)
|
||||
timestep = sigma * num_timesteps
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -447,28 +455,29 @@ def get_noisy_model_input_and_timesteps(
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * num_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
timestep: torch.Tensor = noise_scheduler.timesteps[indices].to(device=device)
|
||||
sigma = get_sigmas(noise_scheduler, timestep, device, n_dim=latents.ndim, dtype=dtype)
|
||||
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
sigma = sigma.view(-1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
assert isinstance(args.ip_noise_gamma, float)
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
noisy_model_input = (1.0 - sigma) * latents + sigma * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||
noisy_model_input = (1.0 - sigma) * latents + sigma * noise
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
return noisy_model_input.to(dtype), timestep.to(dtype), sigma
|
||||
|
||||
|
||||
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
def apply_model_prediction_type(args, model_pred: torch.FloatTensor, noisy_model_input, sigmas):
|
||||
weighting = None
|
||||
if args.model_prediction_type == "raw":
|
||||
pass
|
||||
|
||||
@@ -347,7 +347,7 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
|
||||
return img_ids
|
||||
|
||||
|
||||
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
|
||||
def unpack_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor:
|
||||
"""
|
||||
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
|
||||
"""
|
||||
|
||||
@@ -895,7 +895,7 @@ def compute_density_for_timestep_sampling(
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
@@ -1700,10 +1700,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
latents = image_info.latents_flipped
|
||||
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
|
||||
|
||||
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
|
||||
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
|
||||
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
|
||||
@@ -1715,12 +1719,16 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
latents = torch.FloatTensor(latents)
|
||||
if alpha_mask is not None:
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
|
||||
|
||||
image = None
|
||||
|
||||
images.append(image)
|
||||
latents_list.append(latents)
|
||||
alpha_mask_list.append(alpha_mask)
|
||||
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
|
||||
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
|
||||
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
|
||||
else:
|
||||
if isinstance(image_info, ImageSetInfo):
|
||||
for absolute_path in image_info.absolute_paths:
|
||||
@@ -2543,11 +2551,11 @@ class ControlNetDataset(BaseDataset):
|
||||
subset.token_warmup_min,
|
||||
subset.token_warmup_step,
|
||||
resize_interpolation=subset.resize_interpolation,
|
||||
subset.preference,
|
||||
subset.preference_caption_prefix,
|
||||
subset.preference_caption_suffix,
|
||||
subset.non_preference_caption_prefix,
|
||||
subset.non_preference_caption_suffix,
|
||||
preference=subset.preference,
|
||||
preference_caption_prefix=subset.preference_caption_prefix,
|
||||
preference_caption_suffix=subset.preference_caption_suffix,
|
||||
non_preference_caption_prefix=subset.non_preference_caption_prefix,
|
||||
non_preference_caption_suffix=subset.non_preference_caption_suffix,
|
||||
)
|
||||
db_subsets.append(db_subset)
|
||||
|
||||
|
||||
@@ -6,3 +6,4 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
@@ -323,7 +323,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
@@ -389,7 +389,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
return model_pred, noisy_model_input, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
164
tests/library/test_custom_train_functions_diffusion_dpo.py
Normal file
164
tests/library/test_custom_train_functions_diffusion_dpo.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import torch
|
||||
from unittest.mock import Mock
|
||||
|
||||
from library.custom_train_functions import diffusion_dpo_loss
|
||||
|
||||
def test_diffusion_dpo_loss_basic():
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create dummy loss tensor
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
|
||||
# Mock the call_unet and apply_loss functions
|
||||
mock_unet_output = torch.rand(batch_size, channels, height, width)
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
|
||||
mock_loss_output = torch.rand(batch_size, channels, height, width)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
beta_dpo = 0.1
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, beta_dpo)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check expected metrics are present
|
||||
expected_keys = ["total_loss", "raw_model_loss", "ref_loss", "implicit_acc"]
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], float)
|
||||
|
||||
# Verify mocks were called correctly
|
||||
call_unet.assert_called_once()
|
||||
apply_loss.assert_called_once_with(mock_unet_output)
|
||||
|
||||
def test_diffusion_dpo_loss_shapes():
|
||||
# Test with different tensor shapes
|
||||
shapes = [
|
||||
(2, 4, 32, 32), # Small tensor
|
||||
(4, 16, 64, 64), # Medium tensor
|
||||
(6, 32, 128, 128), # Larger tensor
|
||||
]
|
||||
|
||||
for shape in shapes:
|
||||
loss = torch.rand(*shape)
|
||||
|
||||
# Create mocks
|
||||
mock_unet_output = torch.rand(*shape)
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
|
||||
mock_loss_output = torch.rand(*shape)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)
|
||||
|
||||
# The result should be a scalar tensor
|
||||
assert result.shape == torch.Size([shape[0] // 2])
|
||||
|
||||
# All metrics should be scalars
|
||||
for val in metrics.values():
|
||||
assert isinstance(val, float)
|
||||
|
||||
def test_diffusion_dpo_loss_beta_values():
|
||||
batch_size = 2
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
|
||||
# Create consistent mock returns
|
||||
mock_unet_output = torch.rand(batch_size, channels, height, width)
|
||||
mock_loss_output = torch.rand(batch_size, channels, height, width)
|
||||
|
||||
# Test with different beta values
|
||||
beta_values = [0.0, 0.1, 1.0, 10.0]
|
||||
results = []
|
||||
|
||||
for beta in beta_values:
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, beta)
|
||||
results.append(result.item())
|
||||
|
||||
# With increasing beta, results should be different
|
||||
# This test checks that beta affects the output
|
||||
assert len(set(results)) > 1, "Different beta values should produce different results"
|
||||
|
||||
def test_diffusion_dpo_implicit_acc():
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create controlled test data where winners have lower loss
|
||||
w_loss = torch.ones(batch_size//2, channels, height, width) * 0.2
|
||||
l_loss = torch.ones(batch_size//2, channels, height, width) * 0.8
|
||||
loss = torch.cat([w_loss, l_loss], dim=0)
|
||||
|
||||
# Make the reference loss similar but with less difference
|
||||
ref_w_loss = torch.ones(batch_size//2, channels, height, width) * 0.3
|
||||
ref_l_loss = torch.ones(batch_size//2, channels, height, width) * 0.7
|
||||
ref_loss = torch.cat([ref_w_loss, ref_l_loss], dim=0)
|
||||
|
||||
call_unet = Mock(return_value=torch.zeros_like(loss)) # Dummy, won't be used
|
||||
apply_loss = Mock(return_value=ref_loss)
|
||||
|
||||
# With a positive beta, model_diff > ref_diff should lead to high implicit accuracy
|
||||
result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 1.0)
|
||||
|
||||
# Implicit accuracy should be high (model correctly identifies preferences)
|
||||
assert metrics["implicit_acc"] > 0.5
|
||||
|
||||
def test_diffusion_dpo_gradient_flow():
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create loss tensor that requires gradients
|
||||
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
|
||||
# Create mock outputs
|
||||
mock_unet_output = torch.rand(batch_size, channels, height, width)
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
|
||||
mock_loss_output = torch.rand(batch_size, channels, height, width)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
# Compute loss
|
||||
result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)
|
||||
|
||||
# Check that gradients flow
|
||||
result.mean().backward()
|
||||
|
||||
# Verify gradients flowed through
|
||||
assert loss.grad is not None
|
||||
|
||||
def test_diffusion_dpo_no_ref_grad():
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
|
||||
# Set up mock that tracks if it was called with no_grad
|
||||
mock_unet_output = torch.rand(batch_size, channels, height, width)
|
||||
call_unet = Mock(return_value=mock_unet_output)
|
||||
|
||||
mock_loss_output = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
apply_loss = Mock(return_value=mock_loss_output)
|
||||
|
||||
# Run function
|
||||
result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)
|
||||
result.mean().backward()
|
||||
|
||||
# Check that the reference loss has no gradients (was computed with torch.no_grad())
|
||||
# This is a bit tricky to test directly, but we can verify call_unet was called
|
||||
call_unet.assert_called_once()
|
||||
apply_loss.assert_called_once()
|
||||
|
||||
# The mock_loss_output should not receive gradients as it's used inside torch.no_grad()
|
||||
assert mock_loss_output.grad is None
|
||||
116
tests/library/test_custom_train_functions_mapo.py
Normal file
116
tests/library/test_custom_train_functions_mapo.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from library.custom_train_functions import mapo_loss
|
||||
|
||||
|
||||
def test_mapo_loss_basic():
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create dummy loss tensor with shape [B, C, H, W]
|
||||
loss = torch.rand(batch_size, channels, height, width)
|
||||
mapo_weight = 0.5
|
||||
|
||||
result, metrics = mapo_loss(loss, mapo_weight)
|
||||
|
||||
# Check return types
|
||||
assert isinstance(result, torch.Tensor)
|
||||
assert isinstance(metrics, dict)
|
||||
|
||||
# Check required metrics are present
|
||||
expected_keys = ["total_loss", "ratio_loss", "model_losses_w", "model_losses_l", "win_score", "lose_score"]
|
||||
for key in expected_keys:
|
||||
assert key in metrics
|
||||
assert isinstance(metrics[key], float)
|
||||
|
||||
|
||||
def test_mapo_loss_different_shapes():
|
||||
# Test with different tensor shapes
|
||||
shapes = [
|
||||
(2, 4, 32, 32), # Small tensor
|
||||
(4, 16, 64, 64), # Medium tensor
|
||||
(6, 32, 128, 128), # Larger tensor
|
||||
]
|
||||
|
||||
for shape in shapes:
|
||||
loss = torch.rand(*shape)
|
||||
result, metrics = mapo_loss(loss, 0.5)
|
||||
|
||||
# The result should be a scalar tensor
|
||||
assert result.shape == torch.Size([])
|
||||
|
||||
# All metrics should be scalars
|
||||
for val in metrics.values():
|
||||
assert np.isscalar(val)
|
||||
|
||||
|
||||
def test_mapo_loss_with_zero_weight():
|
||||
loss = torch.rand(4, 3, 64, 64)
|
||||
result, metrics = mapo_loss(loss, 0.0)
|
||||
|
||||
# With zero mapo_weight, ratio_loss should be zero
|
||||
assert metrics["ratio_loss"] == 0.0
|
||||
|
||||
# result should be equal to mean of model_losses_w
|
||||
assert torch.isclose(result, torch.tensor(metrics["model_losses_w"]))
|
||||
|
||||
|
||||
def test_mapo_loss_with_different_timesteps():
|
||||
loss = torch.rand(4, 4, 32, 32)
|
||||
|
||||
# Test with different timestep values
|
||||
timesteps = [1, 10, 100, 1000]
|
||||
|
||||
for ts in timesteps:
|
||||
result, metrics = mapo_loss(loss, 0.5, ts)
|
||||
|
||||
# Check that the results are different for different timesteps
|
||||
if ts > 1:
|
||||
result_prev, metrics_prev = mapo_loss(loss, 0.5, ts // 10)
|
||||
# Log odds should be affected by timesteps, so ratio_loss should change
|
||||
assert metrics["ratio_loss"] != metrics_prev["ratio_loss"]
|
||||
|
||||
|
||||
def test_mapo_loss_win_loss_scores():
|
||||
# Create a controlled input where win losses are lower than lose losses
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create losses where winning examples have lower loss
|
||||
w_loss = torch.ones(batch_size // 2, channels, height, width) * 0.1
|
||||
l_loss = torch.ones(batch_size // 2, channels, height, width) * 0.9
|
||||
|
||||
# Concatenate to create the full loss tensor
|
||||
loss = torch.cat([w_loss, l_loss], dim=0)
|
||||
|
||||
# Run the function
|
||||
result, metrics = mapo_loss(loss, 0.5)
|
||||
|
||||
# Win score should be higher than lose score (better performance)
|
||||
assert metrics["win_score"] > metrics["lose_score"]
|
||||
|
||||
# Model losses for winners should be lower
|
||||
assert metrics["model_losses_w"] < metrics["model_losses_l"]
|
||||
|
||||
|
||||
def test_mapo_loss_gradient_flow():
|
||||
# Test that gradients flow through the loss function
|
||||
batch_size = 4
|
||||
channels = 4
|
||||
height, width = 64, 64
|
||||
|
||||
# Create a loss tensor that requires grad
|
||||
loss = torch.rand(batch_size, channels, height, width, requires_grad=True)
|
||||
mapo_weight = 0.5
|
||||
|
||||
# Compute loss
|
||||
result, _ = mapo_loss(loss, mapo_weight)
|
||||
|
||||
# Check that gradients flow
|
||||
result.backward()
|
||||
|
||||
# If gradients flow, loss.grad should not be None
|
||||
assert loss.grad is not None
|
||||
121
train_network.py
121
train_network.py
@@ -43,6 +43,8 @@ from library.custom_train_functions import (
|
||||
add_v_prediction_like_loss,
|
||||
apply_debiased_estimation,
|
||||
apply_masked_loss,
|
||||
diffusion_dpo_loss,
|
||||
mapo_loss
|
||||
)
|
||||
from library.utils import setup_logging, add_logging_arguments
|
||||
|
||||
@@ -255,18 +257,18 @@ class NetworkTrainer:
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
self,
|
||||
args,
|
||||
accelerator,
|
||||
args: argparse.Namespace,
|
||||
accelerator: Accelerator,
|
||||
noise_scheduler,
|
||||
latents,
|
||||
batch,
|
||||
latents: torch.FloatTensor,
|
||||
batch: dict[str, torch.Tensor],
|
||||
text_encoder_conds,
|
||||
unet,
|
||||
network,
|
||||
weight_dtype,
|
||||
train_unet,
|
||||
weight_dtype: torch.dtype,
|
||||
train_unet: bool,
|
||||
is_train=True,
|
||||
):
|
||||
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
|
||||
@@ -321,9 +323,9 @@ class NetworkTrainer:
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
return noise_pred, noisy_latents, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
@@ -380,10 +382,12 @@ class NetworkTrainer:
|
||||
is_train=True,
|
||||
train_text_encoder=True,
|
||||
train_unet=True,
|
||||
) -> torch.Tensor:
|
||||
multipliers=1.0,
|
||||
) -> tuple[torch.Tensor, dict[str, float | int]]:
|
||||
"""
|
||||
Process a batch for the network
|
||||
"""
|
||||
metrics: dict[str, float | int] = {}
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
@@ -446,7 +450,7 @@ class NetworkTrainer:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
noise_pred, noisy_latents, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
accelerator,
|
||||
noise_scheduler,
|
||||
@@ -468,85 +472,34 @@ class NetworkTrainer:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
if args.beta_dpo is not None:
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
model_loss_w, model_loss_l = model_loss.chunk(2)
|
||||
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
|
||||
model_diff = model_loss_w - model_loss_l
|
||||
|
||||
# ref loss
|
||||
with torch.no_grad():
|
||||
# disable network for reference
|
||||
def call_unet():
|
||||
accelerator.unwrap_model(network).set_multiplier(0.0)
|
||||
ref_noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
|
||||
with accelerator.autocast():
|
||||
ref_noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(1.0)
|
||||
return ref_noise_pred
|
||||
def apply_loss(ref_noise_pred):
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
ref_loss = train_util.conditional_loss(
|
||||
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
|
||||
)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
ref_loss = apply_masked_loss(ref_loss, batch)
|
||||
return ref_loss
|
||||
|
||||
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
|
||||
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
|
||||
ref_diff = ref_losses_w - ref_losses_l
|
||||
raw_ref_loss = ref_loss.mean()
|
||||
|
||||
# reset network multipliers
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
scale_term = -0.5 * args.beta_dpo
|
||||
inside_term = scale_term * (model_diff - ref_diff)
|
||||
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
|
||||
|
||||
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
|
||||
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
|
||||
|
||||
accelerator.log({
|
||||
"total_loss": model_loss.detach().mean().item(),
|
||||
"raw_model_loss": raw_model_loss.detach().mean().item(),
|
||||
"ref_loss": raw_ref_loss.detach().item(),
|
||||
"implicit_acc": implicit_acc.detach().item(),
|
||||
}, step=global_step)
|
||||
loss, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, args.beta_dpo)
|
||||
elif args.mapo_weight is not None:
|
||||
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
|
||||
|
||||
snr = 0.5
|
||||
model_losses_w, model_losses_l = model_loss.chunk(2)
|
||||
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
|
||||
snr * model_losses_l
|
||||
) / (torch.exp(snr * model_losses_l) - 1)
|
||||
|
||||
# Ratio loss.
|
||||
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
|
||||
ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps)
|
||||
ratio_losses = args.mapo_weight * ratio
|
||||
|
||||
# Full MaPO loss
|
||||
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
|
||||
|
||||
accelerator.log({
|
||||
"total_loss": loss.detach().mean().item(),
|
||||
"ratio_loss": -ratio_losses.mean().detach().item(),
|
||||
"model_losses_w": model_losses_w.mean().detach().item(),
|
||||
"model_losses_l": model_losses_l.mean().detach().item(),
|
||||
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1))
|
||||
.mean()
|
||||
.detach()
|
||||
.item(),
|
||||
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1))
|
||||
.mean()
|
||||
.detach()
|
||||
.item()
|
||||
}, step=global_step)
|
||||
loss, metrics = mapo_loss(loss, args.mapo_weight, noise_scheduler.config.num_train_timesteps)
|
||||
else:
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
@@ -555,7 +508,7 @@ class NetworkTrainer:
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
|
||||
return loss.mean()
|
||||
return loss.mean(), metrics
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -1482,7 +1435,7 @@ class NetworkTrainer:
|
||||
# preprocess batch for each model
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||
|
||||
loss = self.process_batch(
|
||||
loss, batch_metrics = self.process_batch(
|
||||
batch,
|
||||
text_encoders,
|
||||
unet,
|
||||
@@ -1584,7 +1537,7 @@ class NetworkTrainer:
|
||||
mean_grad_norm,
|
||||
mean_combined_norm,
|
||||
)
|
||||
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||
self.step_logging(accelerator, {**logs, **batch_metrics}, global_step, epoch + 1)
|
||||
|
||||
# VALIDATION PER STEP: global_step is already incremented
|
||||
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
|
||||
|
||||
Reference in New Issue
Block a user