Update PO cached latents, move out functions, update calls

This commit is contained in:
rockerBOO
2025-04-27 17:38:50 -04:00
parent 74529743d4
commit d22c827544
11 changed files with 480 additions and 129 deletions

View File

@@ -336,24 +336,24 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_noise_pred_and_target(
self,
args,
accelerator,
args: argparse.Namespace,
accelerator: Accelerator,
noise_scheduler,
latents,
batch,
latents: torch.FloatTensor,
batch: dict[str, torch.Tensor],
text_encoder_conds,
unet: flux_models.Flux,
unet,
network,
weight_dtype,
train_unet,
weight_dtype: torch.dtype,
train_unet: bool,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# get noisy model input and timesteps
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps(
noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
)
@@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss

View File

@@ -4,7 +4,7 @@ import argparse
import random
import re
from torch.types import Number
from typing import List, Optional, Union
from typing import List, Optional, Union, Callable
from .utils import setup_logging
setup_logging()
@@ -502,6 +502,106 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
loss = loss * mask_image
return loss
def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor], apply_loss: Callable[[torch.Tensor], torch.Tensor], beta_dpo: float):
    """Diffusion-DPO loss over (winner, loser) preference pairs.

    Args:
        loss: per-element model loss for stacked pairs, shape [B, C, H, W];
            the first B//2 samples are the preferred ("w") half, the rest the
            non-preferred ("l") half
        call_unet: zero-argument callable returning the reference model prediction
        apply_loss: maps the reference prediction to a per-element reference loss
        beta_dpo: DPO temperature (beta)

    Returns:
        tuple:
            - loss: per-pair DPO loss, shape [B // 2]
            - metrics: dict with "total_loss", "raw_model_loss", "ref_loss" and
              "implicit_acc" scalars for logging
    """
    # Per-sample mean over all non-batch dims, then split into w/l halves.
    per_sample = loss.mean(dim=list(range(1, loss.ndim)))
    model_loss_w, model_loss_l = per_sample.chunk(2)
    raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
    model_diff = model_loss_w - model_loss_l

    # Reference branch is evaluated without gradients.
    with torch.no_grad():
        ref_pred = call_unet()
        ref_elem_loss = apply_loss(ref_pred)
        ref_per_sample = ref_elem_loss.mean(dim=list(range(1, ref_elem_loss.ndim)))
        ref_losses_w, ref_losses_l = ref_per_sample.chunk(2)
        ref_diff = ref_losses_w - ref_losses_l
        raw_ref_loss = ref_per_sample.mean()

    inside_term = (-0.5 * beta_dpo) * (model_diff - ref_diff)
    loss = -1 * torch.nn.functional.logsigmoid(inside_term)

    # Fraction of pairs the implicit reward model ranks correctly; ties count half.
    pair_count = inside_term.size(0)
    implicit_acc = (inside_term > 0).sum().float() / pair_count
    implicit_acc += 0.5 * (inside_term == 0).sum().float() / pair_count

    metrics = {
        "total_loss": per_sample.detach().mean().item(),
        "raw_model_loss": raw_model_loss.detach().mean().item(),
        "ref_loss": raw_ref_loss.detach().item(),
        "implicit_acc": implicit_acc.detach().item(),
    }
    return loss, metrics
def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]:
"""
MaPO loss
Args:
loss: pairs of w, l losses B//2, C, H, W
mapo_loss: mapo weight
num_train_timesteps: number of timesteps
Returns:
tuple:
- loss: mean loss of C, H, W
- metrics:
- total_loss: mean loss of C, H, W
- ratio_loss: mean ratio loss of C, H, W
- model_losses_w: mean loss of w losses of C, H, W
- model_losses_l: mean loss of l losses of C, H, W
- win_score : mean win score of C, H, W
- lose_score : mean lose score of C, H, W
"""
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
snr = 0.5
model_losses_w, model_losses_l = model_loss.chunk(2)
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
snr * model_losses_l
) / (torch.exp(snr * model_losses_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps)
ratio_losses = mapo_weight * ratio
# Full MaPO loss
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
metrics = {
"total_loss": loss.detach().mean().item(),
"ratio_loss": -ratio_losses.detach().mean().item(),
"model_losses_w": model_losses_w.detach().mean().item(),
"model_losses_l": model_losses_l.detach().mean().item(),
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)).detach().mean().item(),
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)).detach().mean().item(),
}
return loss, metrics
"""
##########################################

View File

@@ -359,7 +359,7 @@ def denoise(
# region train
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) -> torch.FloatTensor:
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(device)
timesteps = timesteps.to(device)
@@ -390,7 +390,7 @@ def compute_density_for_timestep_sampling(
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas) -> torch.Tensor:
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
@@ -407,35 +407,43 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def get_noisy_model_input_and_timesteps(
def get_noisy_model_input_and_timestep(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""
Returns:
tuple[
noisy_model_input: noisy at sigma applied to latent
timesteps: timesteps between 1.0 and 1000.0
sigmas: sigmas between 0.0 and 1.0
]
"""
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
num_timesteps: int = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigma = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
sigmas = torch.rand((bsz,), device=device)
sigma = torch.rand((bsz,), device=device)
timesteps = sigmas * num_timesteps
timestep = sigma * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
sigma = torch.randn(bsz, device=device)
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
sigma = sigma.sigmoid()
sigma = (sigma * shift) / (1 + (shift - 1) * sigma)
timestep = sigma * num_timesteps
elif args.timestep_sampling == "flux_shift":
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigma = torch.randn(bsz, device=device)
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
sigma = sigma.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
sigma = time_shift(mu, 1.0, sigma)
timestep = sigma * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -447,28 +455,29 @@ def get_noisy_model_input_and_timesteps(
mode_scale=args.mode_scale,
)
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
timestep: torch.Tensor = noise_scheduler.timesteps[indices].to(device=device)
sigma = get_sigmas(noise_scheduler, timestep, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
sigma = sigma.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
assert isinstance(args.ip_noise_gamma, float)
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
noisy_model_input = (1.0 - sigma) * latents + sigma * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
noisy_model_input = (1.0 - sigma) * latents + sigma * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
return noisy_model_input.to(dtype), timestep.to(dtype), sigma
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
def apply_model_prediction_type(args, model_pred: torch.FloatTensor, noisy_model_input, sigmas):
weighting = None
if args.model_prediction_type == "raw":
pass

View File

@@ -347,7 +347,7 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi
return img_ids
def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
def unpack_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor:
"""
x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
"""

View File

@@ -895,7 +895,7 @@ def compute_density_for_timestep_sampling(
return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
"""Computes loss weighting scheme for SD3 training.
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

View File

@@ -1700,10 +1700,14 @@ class BaseDataset(torch.utils.data.Dataset):
latents = image_info.latents_flipped
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
image = None
images.append(image)
latents_list.append(latents)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
@@ -1715,12 +1719,16 @@ class BaseDataset(torch.utils.data.Dataset):
latents = torch.FloatTensor(latents)
if alpha_mask is not None:
alpha_mask = torch.FloatTensor(alpha_mask)
target_size = (latents.shape[2] * 8, latents.shape[1] * 8)
image = None
images.append(image)
latents_list.append(latents)
alpha_mask_list.append(alpha_mask)
original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0])))
target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
else:
if isinstance(image_info, ImageSetInfo):
for absolute_path in image_info.absolute_paths:
@@ -2543,11 +2551,11 @@ class ControlNetDataset(BaseDataset):
subset.token_warmup_min,
subset.token_warmup_step,
resize_interpolation=subset.resize_interpolation,
subset.preference,
subset.preference_caption_prefix,
subset.preference_caption_suffix,
subset.non_preference_caption_prefix,
subset.non_preference_caption_suffix,
preference=subset.preference,
preference_caption_prefix=subset.preference_caption_prefix,
preference_caption_suffix=subset.preference_caption_suffix,
non_preference_caption_prefix=subset.non_preference_caption_prefix,
non_preference_caption_suffix=subset.non_preference_caption_suffix,
)
db_subsets.append(db_subset)

View File

@@ -6,3 +6,4 @@ filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning
pythonpath = .

View File

@@ -323,7 +323,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
weight_dtype,
train_unet,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
@@ -389,7 +389,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, weighting
return model_pred, noisy_model_input, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss

View File

@@ -0,0 +1,164 @@
import torch
from unittest.mock import Mock
from library.custom_train_functions import diffusion_dpo_loss
def test_diffusion_dpo_loss_basic():
    """diffusion_dpo_loss returns a tensor plus a metrics dict of float entries."""
    shape = (4, 4, 64, 64)
    loss = torch.rand(*shape)

    unet_output = torch.rand(*shape)
    call_unet = Mock(return_value=unet_output)
    apply_loss = Mock(return_value=torch.rand(*shape))

    result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)

    assert isinstance(result, torch.Tensor)
    assert isinstance(metrics, dict)
    for key in ("total_loss", "raw_model_loss", "ref_loss", "implicit_acc"):
        assert key in metrics
        assert isinstance(metrics[key], float)

    # The reference branch must run exactly once, feeding the unet output to apply_loss.
    call_unet.assert_called_once()
    apply_loss.assert_called_once_with(unet_output)
def test_diffusion_dpo_loss_shapes():
    """The per-pair loss has shape [B // 2] across a range of input sizes."""
    for shape in ((2, 4, 32, 32), (4, 16, 64, 64), (6, 32, 128, 128)):
        loss = torch.rand(*shape)
        result, metrics = diffusion_dpo_loss(
            loss,
            Mock(return_value=torch.rand(*shape)),
            Mock(return_value=torch.rand(*shape)),
            0.1,
        )
        # One loss entry per (w, l) pair.
        assert result.shape == torch.Size([shape[0] // 2])
        # Every logged metric is a plain Python float.
        assert all(isinstance(value, float) for value in metrics.values())
def test_diffusion_dpo_loss_beta_values():
    """Changing beta_dpo changes the resulting loss value."""
    shape = (2, 4, 64, 64)
    loss = torch.rand(*shape)
    # Fixed reference tensors so only beta varies between calls.
    unet_output = torch.rand(*shape)
    ref_loss = torch.rand(*shape)

    results = []
    for beta in (0.0, 0.1, 1.0, 10.0):
        result, _ = diffusion_dpo_loss(
            loss, Mock(return_value=unet_output), Mock(return_value=ref_loss), beta
        )
        results.append(result.item())

    assert len(set(results)) > 1, "Different beta values should produce different results"
def test_diffusion_dpo_implicit_acc():
    """When the model separates pairs more than the reference, implicit accuracy is high."""
    half = (2, 4, 64, 64)
    # Model: winners have much lower loss than losers (margin 0.6).
    loss = torch.cat([torch.ones(*half) * 0.2, torch.ones(*half) * 0.8], dim=0)
    # Reference: same ordering but a smaller margin (0.4).
    ref = torch.cat([torch.ones(*half) * 0.3, torch.ones(*half) * 0.7], dim=0)

    # call_unet's return value is ignored because apply_loss is mocked.
    _, metrics = diffusion_dpo_loss(
        loss, Mock(return_value=torch.zeros_like(loss)), Mock(return_value=ref), 1.0
    )

    # The model out-separates the reference, so implicit accuracy must exceed chance.
    assert metrics["implicit_acc"] > 0.5
def test_diffusion_dpo_gradient_flow():
    """Gradients propagate from the DPO loss back to the input loss tensor."""
    shape = (4, 4, 64, 64)
    loss = torch.rand(*shape, requires_grad=True)

    result, _ = diffusion_dpo_loss(
        loss,
        Mock(return_value=torch.rand(*shape)),
        Mock(return_value=torch.rand(*shape)),
        0.1,
    )
    result.mean().backward()

    assert loss.grad is not None
def test_diffusion_dpo_no_ref_grad():
    """The reference branch runs under torch.no_grad(), so it receives no gradients."""
    shape = (4, 4, 64, 64)
    loss = torch.rand(*shape, requires_grad=True)

    call_unet = Mock(return_value=torch.rand(*shape))
    ref_out = torch.rand(*shape, requires_grad=True)
    apply_loss = Mock(return_value=ref_out)

    result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1)
    result.mean().backward()

    # The reference path was exercised exactly once...
    call_unet.assert_called_once()
    apply_loss.assert_called_once()
    # ...but because it ran inside torch.no_grad(), no gradient reaches its output.
    assert ref_out.grad is None

View File

@@ -0,0 +1,116 @@
import torch
import numpy as np
from library.custom_train_functions import mapo_loss
def test_mapo_loss_basic():
    """mapo_loss returns a tensor plus a metrics dict with all expected float entries."""
    loss = torch.rand(4, 4, 64, 64)

    result, metrics = mapo_loss(loss, 0.5)

    assert isinstance(result, torch.Tensor)
    assert isinstance(metrics, dict)
    for key in ("total_loss", "ratio_loss", "model_losses_w", "model_losses_l", "win_score", "lose_score"):
        assert key in metrics
        assert isinstance(metrics[key], float)
def test_mapo_loss_different_shapes():
    """The MaPO loss reduces to a scalar across a range of input sizes."""
    for shape in ((2, 4, 32, 32), (4, 16, 64, 64), (6, 32, 128, 128)):
        result, metrics = mapo_loss(torch.rand(*shape), 0.5)
        # A 0-dim tensor: the loss must be fully reduced.
        assert result.shape == torch.Size([])
        # Every logged metric is a plain scalar.
        assert all(np.isscalar(value) for value in metrics.values())
def test_mapo_loss_with_zero_weight():
    """With mapo_weight == 0 the ratio term vanishes and only the winner loss remains."""
    result, metrics = mapo_loss(torch.rand(4, 3, 64, 64), 0.0)
    assert metrics["ratio_loss"] == 0.0
    # The loss collapses to the mean winner loss.
    assert torch.isclose(result, torch.tensor(metrics["model_losses_w"]))
def test_mapo_loss_with_different_timesteps():
    """The ratio term scales with num_train_timesteps.

    Uses a deterministic, non-saturating input: the original version fed
    unseeded torch.rand data and compared ratio_loss at ts vs ts // 10, but
    logsigmoid(log_odds * ts) saturates to exactly 0.0 for large ts, so two
    consecutive values could compare equal and fail the assert spuriously.
    Constant w/l losses of 0.4 / 0.6 keep log_odds ~0.046, which stays in the
    non-saturated range for the timestep values below.
    """
    half = (2, 4, 32, 32)
    loss = torch.cat([torch.ones(*half) * 0.4, torch.ones(*half) * 0.6], dim=0)

    ratio_losses = []
    for ts in (1, 5, 10, 50):
        _, metrics = mapo_loss(loss, 0.5, ts)
        ratio_losses.append(metrics["ratio_loss"])

    # Each timestep count must yield a distinct ratio loss.
    assert len(set(ratio_losses)) == len(ratio_losses)
def test_mapo_loss_win_loss_scores():
    """Lower winner losses must yield a higher win score than lose score."""
    half = (2, 4, 64, 64)
    # Winners get a much lower per-element loss than losers.
    loss = torch.cat([torch.ones(*half) * 0.1, torch.ones(*half) * 0.9], dim=0)

    _, metrics = mapo_loss(loss, 0.5)

    # Better (lower) winner loss implies a higher score.
    assert metrics["win_score"] > metrics["lose_score"]
    assert metrics["model_losses_w"] < metrics["model_losses_l"]
def test_mapo_loss_gradient_flow():
    """Gradients propagate from the MaPO loss back to the input loss tensor."""
    loss = torch.rand(4, 4, 64, 64, requires_grad=True)

    result, _ = mapo_loss(loss, 0.5)
    result.backward()

    assert loss.grad is not None

View File

@@ -43,6 +43,8 @@ from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_debiased_estimation,
apply_masked_loss,
diffusion_dpo_loss,
mapo_loss
)
from library.utils import setup_logging, add_logging_arguments
@@ -255,18 +257,18 @@ class NetworkTrainer:
def get_noise_pred_and_target(
self,
args,
accelerator,
args: argparse.Namespace,
accelerator: Accelerator,
noise_scheduler,
latents,
batch,
latents: torch.FloatTensor,
batch: dict[str, torch.Tensor],
text_encoder_conds,
unet,
network,
weight_dtype,
train_unet,
weight_dtype: torch.dtype,
train_unet: bool,
is_train=True,
):
) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]:
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
@@ -321,9 +323,9 @@ class NetworkTrainer:
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
return noise_pred, target, timesteps, None
return noise_pred, noisy_latents, target, timesteps, None
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
@@ -380,10 +382,12 @@ class NetworkTrainer:
is_train=True,
train_text_encoder=True,
train_unet=True,
) -> torch.Tensor:
multipliers=1.0,
) -> tuple[torch.Tensor, dict[str, float | int]]:
"""
Process a batch for the network
"""
metrics: dict[str, float | int] = {}
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
@@ -446,7 +450,7 @@ class NetworkTrainer:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
noise_pred, noisy_latents, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
accelerator,
noise_scheduler,
@@ -468,85 +472,34 @@ class NetworkTrainer:
loss = apply_masked_loss(loss, batch)
if args.beta_dpo is not None:
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
model_loss_w, model_loss_l = model_loss.chunk(2)
raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean())
model_diff = model_loss_w - model_loss_l
# ref loss
with torch.no_grad():
# disable network for reference
def call_unet():
accelerator.unwrap_model(network).set_multiplier(0.0)
ref_noise_pred = self.call_unet(
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
with accelerator.autocast():
ref_noise_pred = self.call_unet(
args,
accelerator,
unet,
noisy_latents.requires_grad_(train_unet),
timesteps,
text_encoder_conds,
batch,
weight_dtype,
)
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(1.0)
return ref_noise_pred
def apply_loss(ref_noise_pred):
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
ref_loss = train_util.conditional_loss(
ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
ref_loss = apply_masked_loss(ref_loss, batch)
return ref_loss
ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))
ref_losses_w, ref_losses_l = ref_loss.chunk(2)
ref_diff = ref_losses_w - ref_losses_l
raw_ref_loss = ref_loss.mean()
# reset network multipliers
accelerator.unwrap_model(network).set_multiplier(multipliers)
scale_term = -0.5 * args.beta_dpo
inside_term = scale_term * (model_diff - ref_diff)
loss = -1 * torch.nn.functional.logsigmoid(inside_term)
implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0)
implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0)
accelerator.log({
"total_loss": model_loss.detach().mean().item(),
"raw_model_loss": raw_model_loss.detach().mean().item(),
"ref_loss": raw_ref_loss.detach().item(),
"implicit_acc": implicit_acc.detach().item(),
}, step=global_step)
loss, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, args.beta_dpo)
elif args.mapo_weight is not None:
model_loss = loss.mean(dim=list(range(1, len(loss.shape))))
snr = 0.5
model_losses_w, model_losses_l = model_loss.chunk(2)
log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - (
snr * model_losses_l
) / (torch.exp(snr * model_losses_l) - 1)
# Ratio loss.
# By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process.
ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps)
ratio_losses = args.mapo_weight * ratio
# Full MaPO loss
loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape))))
accelerator.log({
"total_loss": loss.detach().mean().item(),
"ratio_loss": -ratio_losses.mean().detach().item(),
"model_losses_w": model_losses_w.mean().detach().item(),
"model_losses_l": model_losses_l.mean().detach().item(),
"win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1))
.mean()
.detach()
.item(),
"lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1))
.mean()
.detach()
.item()
}, step=global_step)
loss, metrics = mapo_loss(loss, args.mapo_weight, noise_scheduler.config.num_train_timesteps)
else:
loss = loss.mean([1, 2, 3])
@@ -555,7 +508,7 @@ class NetworkTrainer:
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
return loss.mean()
return loss.mean(), metrics
def train(self, args):
session_id = random.randint(0, 2**32)
@@ -1482,7 +1435,7 @@ class NetworkTrainer:
# preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch(
loss, batch_metrics = self.process_batch(
batch,
text_encoders,
unet,
@@ -1584,7 +1537,7 @@ class NetworkTrainer:
mean_grad_norm,
mean_combined_norm,
)
self.step_logging(accelerator, logs, global_step, epoch + 1)
self.step_logging(accelerator, {**logs, **batch_metrics}, global_step, epoch + 1)
# VALIDATION PER STEP: global_step is already incremented
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...