diff --git a/flux_train_network.py b/flux_train_network.py index def44155..c619afac 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -336,24 +336,24 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): def get_noise_pred_and_target( self, - args, - accelerator, + args: argparse.Namespace, + accelerator: Accelerator, noise_scheduler, - latents, - batch, + latents: torch.FloatTensor, + batch: dict[str, torch.Tensor], text_encoder_conds, - unet: flux_models.Flux, + unet, network, - weight_dtype, - train_unet, + weight_dtype: torch.dtype, + train_unet: bool, is_train=True, - ): + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timesteps( + noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep( args, noise_scheduler, latents, noise, accelerator.device, weight_dtype ) @@ -448,7 +448,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): ) target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, weighting + return model_pred, noisy_model_input, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index ad3e69ff..18ad8234 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -4,7 +4,7 @@ import argparse import random import re from torch.types import Number -from typing import List, Optional, Union +from typing import List, Optional, Union, Callable from .utils import setup_logging setup_logging() @@ -502,6 +502,106 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: loss = loss * mask_image return 
loss +def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor], apply_loss: Callable[[torch.Tensor], torch.Tensor], beta_dpo: float): + """ + DPO loss + + Args: + loss: pairs of w, l losses B//2, C, H, W + call_unet: function to call unet + apply_loss: function to apply loss + beta_dpo: beta_dpo weight + + Returns: + tuple: + - loss: mean loss of C, H, W + - metrics: + - total_loss: mean loss of C, H, W + - raw_model_loss: mean loss of C, H, W + - ref_loss: mean loss of C, H, W + - implicit_acc: implicit preference accuracy + + """ + + model_loss = loss.mean(dim=list(range(1, len(loss.shape)))) + model_loss_w, model_loss_l = model_loss.chunk(2) + raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean()) + model_diff = model_loss_w - model_loss_l + + # ref loss + with torch.no_grad(): + ref_noise_pred = call_unet() + ref_loss = apply_loss(ref_noise_pred) + ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape)))) + ref_losses_w, ref_losses_l = ref_loss.chunk(2) + ref_diff = ref_losses_w - ref_losses_l + raw_ref_loss = ref_loss.mean() + + + scale_term = -0.5 * beta_dpo + inside_term = scale_term * (model_diff - ref_diff) + loss = -1 * torch.nn.functional.logsigmoid(inside_term) + + implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) + implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) + + metrics = { + "total_loss": model_loss.detach().mean().item(), + "raw_model_loss": raw_model_loss.detach().mean().item(), + "ref_loss": raw_ref_loss.detach().item(), + "implicit_acc": implicit_acc.detach().item(), + } + + return loss, metrics + +def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) -> tuple[torch.Tensor, dict[str, int | float]]: + """ + MaPO loss + + Args: + loss: pairs of w, l losses B//2, C, H, W + mapo_weight: mapo weight + num_train_timesteps: number of timesteps + + Returns: + tuple: + - loss: mean loss of C, H, W + - metrics: + - total_loss: mean loss 
of C, H, W + - ratio_loss: mean ratio loss of C, H, W + - model_losses_w: mean loss of w losses of C, H, W + - model_losses_l: mean loss of l losses of C, H, W + - win_score : mean win score of C, H, W + - lose_score : mean lose score of C, H, W + + """ + model_loss = loss.mean(dim=list(range(1, len(loss.shape)))) + + snr = 0.5 + model_losses_w, model_losses_l = model_loss.chunk(2) + log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - ( + snr * model_losses_l + ) / (torch.exp(snr * model_losses_l) - 1) + + # Ratio loss. + # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process. + ratio = torch.nn.functional.logsigmoid(log_odds * num_train_timesteps) + ratio_losses = mapo_weight * ratio + + # Full MaPO loss + loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape)))) + + metrics = { + "total_loss": loss.detach().mean().item(), + "ratio_loss": -ratio_losses.detach().mean().item(), + "model_losses_w": model_losses_w.detach().mean().item(), + "model_losses_l": model_losses_l.detach().mean().item(), + "win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)).detach().mean().item(), + "lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)).detach().mean().item(), + } + + return loss, metrics + """ ########################################## diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0e73a01d..b8cf40de 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -359,7 +359,7 @@ def denoise( # region train -def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): +def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) -> torch.FloatTensor: sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(device) timesteps = 
timesteps.to(device) @@ -390,7 +390,7 @@ def compute_density_for_timestep_sampling( return u -def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas) -> torch.Tensor: """Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. @@ -407,35 +407,43 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): return weighting -def get_noisy_model_input_and_timesteps( +def get_noisy_model_input_and_timestep( args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Returns: + tuple[ + noisy_model_input: noisy at sigma applied to latent + timesteps: timesteps betweeen 1.0 and 1000.0 + sigmas: sigmas between 0.0 and 1.0 + ] + """ bsz, _, h, w = latents.shape assert bsz > 0, "Batch size not large enough" - num_timesteps = noise_scheduler.config.num_train_timesteps + num_timesteps: int = noise_scheduler.config.num_train_timesteps if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": # Simple random sigma-based noise sampling if args.timestep_sampling == "sigmoid": # https://github.com/XLabs-AI/x-flux/tree/main - sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + sigma = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: - sigmas = torch.rand((bsz,), device=device) + sigma = torch.rand((bsz,), device=device) - timesteps = sigmas * num_timesteps + timestep = sigma * num_timesteps elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift - sigmas = torch.randn(bsz, device=device) - sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling - sigmas = sigmas.sigmoid() - sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) 
- timesteps = sigmas * num_timesteps + sigma = torch.randn(bsz, device=device) + sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling + sigma = sigma.sigmoid() + sigma = (sigma * shift) / (1 + (shift - 1) * sigma) + timestep = sigma * num_timesteps elif args.timestep_sampling == "flux_shift": - sigmas = torch.randn(bsz, device=device) - sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling - sigmas = sigmas.sigmoid() + sigma = torch.randn(bsz, device=device) + sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling + sigma = sigma.sigmoid() mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size - sigmas = time_shift(mu, 1.0, sigmas) - timesteps = sigmas * num_timesteps + sigma = time_shift(mu, 1.0, sigma) + timestep = sigma * num_timesteps else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -447,28 +455,29 @@ def get_noisy_model_input_and_timesteps( mode_scale=args.mode_scale, ) indices = (u * num_timesteps).long() - timesteps = noise_scheduler.timesteps[indices].to(device=device) - sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) + timestep: torch.Tensor = noise_scheduler.timesteps[indices].to(device=device) + sigma = get_sigmas(noise_scheduler, timestep, device, n_dim=latents.ndim, dtype=dtype) # Broadcast sigmas to latent shape - sigmas = sigmas.view(-1, 1, 1, 1) + sigma = sigma.view(-1, 1, 1, 1) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.ip_noise_gamma: + assert isinstance(args.ip_noise_gamma, float) xi = torch.randn_like(latents, device=latents.device, dtype=dtype) if args.ip_noise_gamma_random_strength: ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) else: ip_noise_gamma = args.ip_noise_gamma - noisy_model_input 
= (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) + noisy_model_input = (1.0 - sigma) * latents + sigma * (noise + ip_noise_gamma * xi) else: - noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise + noisy_model_input = (1.0 - sigma) * latents + sigma * noise - return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas + return noisy_model_input.to(dtype), timestep.to(dtype), sigma -def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas): +def apply_model_prediction_type(args, model_pred: torch.FloatTensor, noisy_model_input, sigmas): weighting = None if args.model_prediction_type == "raw": pass diff --git a/library/flux_utils.py b/library/flux_utils.py index 8be1d63e..7e8f9f61 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -347,7 +347,7 @@ def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_wi return img_ids -def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor: +def unpack_latents(x: torch.FloatTensor, packed_latent_height: int, packed_latent_width: int) -> torch.FloatTensor: """ x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2 """ diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index c4079884..fe03e8fc 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -895,7 +895,7 @@ def compute_density_for_timestep_sampling( return u -def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): +def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor: """Computes loss weighting scheme for SD3 training. Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. 
diff --git a/library/train_util.py b/library/train_util.py index 08bb7dfd..178b0948 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1700,10 +1700,14 @@ class BaseDataset(torch.utils.data.Dataset): latents = image_info.latents_flipped alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) + target_size = (latents.shape[2] * 8, latents.shape[1] * 8) image = None images.append(image) latents_list.append(latents) + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 latents, original_size, crop_ltrb, flipped_latents, alpha_mask = ( self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso) @@ -1715,12 +1719,16 @@ class BaseDataset(torch.utils.data.Dataset): latents = torch.FloatTensor(latents) if alpha_mask is not None: alpha_mask = torch.FloatTensor(alpha_mask) + target_size = (latents.shape[2] * 8, latents.shape[1] * 8) image = None images.append(image) latents_list.append(latents) alpha_mask_list.append(alpha_mask) + original_sizes_hw.append((int(original_size[1]), int(original_size[0]))) + crop_top_lefts.append((int(crop_ltrb[1]), int(crop_ltrb[0]))) + target_sizes_hw.append((int(target_size[1]), int(target_size[0]))) else: if isinstance(image_info, ImageSetInfo): for absolute_path in image_info.absolute_paths: @@ -2543,11 +2551,11 @@ class ControlNetDataset(BaseDataset): subset.token_warmup_min, subset.token_warmup_step, resize_interpolation=subset.resize_interpolation, - subset.preference, - subset.preference_caption_prefix, - subset.preference_caption_suffix, - subset.non_preference_caption_prefix, - subset.non_preference_caption_suffix, + preference=subset.preference, + 
preference_caption_prefix=subset.preference_caption_prefix, + preference_caption_suffix=subset.preference_caption_suffix, + non_preference_caption_prefix=subset.non_preference_caption_prefix, + non_preference_caption_suffix=subset.non_preference_caption_suffix, ) db_subsets.append(db_subset) diff --git a/pytest.ini b/pytest.ini index 484d3aef..34b7e9c1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -6,3 +6,4 @@ filterwarnings = ignore::DeprecationWarning ignore::UserWarning ignore::FutureWarning +pythonpath = . diff --git a/sd3_train_network.py b/sd3_train_network.py index cdb7aa4e..fdb2b356 100644 --- a/sd3_train_network.py +++ b/sd3_train_network.py @@ -323,7 +323,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): weight_dtype, train_unet, is_train=True, - ): + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents) @@ -389,7 +389,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer): target[diff_output_pr_indices] = model_pred_prior.to(target.dtype) - return model_pred, target, timesteps, weighting + return model_pred, noisy_model_input, target, timesteps, weighting def post_process_loss(self, loss, args, timesteps, noise_scheduler): return loss diff --git a/tests/library/test_custom_train_functions_diffusion_dpo.py b/tests/library/test_custom_train_functions_diffusion_dpo.py new file mode 100644 index 00000000..140ecf54 --- /dev/null +++ b/tests/library/test_custom_train_functions_diffusion_dpo.py @@ -0,0 +1,164 @@ +import torch +from unittest.mock import Mock + +from library.custom_train_functions import diffusion_dpo_loss + +def test_diffusion_dpo_loss_basic(): + batch_size = 4 + channels = 4 + height, width = 64, 64 + + # Create dummy loss tensor + loss = torch.rand(batch_size, channels, height, width) + + # Mock the call_unet and apply_loss functions + mock_unet_output = torch.rand(batch_size, channels, height, width) 
+ call_unet = Mock(return_value=mock_unet_output) + + mock_loss_output = torch.rand(batch_size, channels, height, width) + apply_loss = Mock(return_value=mock_loss_output) + + beta_dpo = 0.1 + + result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, beta_dpo) + + # Check return types + assert isinstance(result, torch.Tensor) + assert isinstance(metrics, dict) + + # Check expected metrics are present + expected_keys = ["total_loss", "raw_model_loss", "ref_loss", "implicit_acc"] + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], float) + + # Verify mocks were called correctly + call_unet.assert_called_once() + apply_loss.assert_called_once_with(mock_unet_output) + +def test_diffusion_dpo_loss_shapes(): + # Test with different tensor shapes + shapes = [ + (2, 4, 32, 32), # Small tensor + (4, 16, 64, 64), # Medium tensor + (6, 32, 128, 128), # Larger tensor + ] + + for shape in shapes: + loss = torch.rand(*shape) + + # Create mocks + mock_unet_output = torch.rand(*shape) + call_unet = Mock(return_value=mock_unet_output) + + mock_loss_output = torch.rand(*shape) + apply_loss = Mock(return_value=mock_loss_output) + + result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1) + + # The result should be a scalar tensor + assert result.shape == torch.Size([shape[0] // 2]) + + # All metrics should be scalars + for val in metrics.values(): + assert isinstance(val, float) + +def test_diffusion_dpo_loss_beta_values(): + batch_size = 2 + channels = 4 + height, width = 64, 64 + + loss = torch.rand(batch_size, channels, height, width) + + # Create consistent mock returns + mock_unet_output = torch.rand(batch_size, channels, height, width) + mock_loss_output = torch.rand(batch_size, channels, height, width) + + # Test with different beta values + beta_values = [0.0, 0.1, 1.0, 10.0] + results = [] + + for beta in beta_values: + call_unet = Mock(return_value=mock_unet_output) + apply_loss = Mock(return_value=mock_loss_output) 
+ + result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, beta) + results.append(result.item()) + + # With increasing beta, results should be different + # This test checks that beta affects the output + assert len(set(results)) > 1, "Different beta values should produce different results" + +def test_diffusion_dpo_implicit_acc(): + batch_size = 4 + channels = 4 + height, width = 64, 64 + + # Create controlled test data where winners have lower loss + w_loss = torch.ones(batch_size//2, channels, height, width) * 0.2 + l_loss = torch.ones(batch_size//2, channels, height, width) * 0.8 + loss = torch.cat([w_loss, l_loss], dim=0) + + # Make the reference loss similar but with less difference + ref_w_loss = torch.ones(batch_size//2, channels, height, width) * 0.3 + ref_l_loss = torch.ones(batch_size//2, channels, height, width) * 0.7 + ref_loss = torch.cat([ref_w_loss, ref_l_loss], dim=0) + + call_unet = Mock(return_value=torch.zeros_like(loss)) # Dummy, won't be used + apply_loss = Mock(return_value=ref_loss) + + # With a positive beta, model_diff > ref_diff should lead to high implicit accuracy + result, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, 1.0) + + # Implicit accuracy should be high (model correctly identifies preferences) + assert metrics["implicit_acc"] > 0.5 + +def test_diffusion_dpo_gradient_flow(): + batch_size = 4 + channels = 4 + height, width = 64, 64 + + # Create loss tensor that requires gradients + loss = torch.rand(batch_size, channels, height, width, requires_grad=True) + + # Create mock outputs + mock_unet_output = torch.rand(batch_size, channels, height, width) + call_unet = Mock(return_value=mock_unet_output) + + mock_loss_output = torch.rand(batch_size, channels, height, width) + apply_loss = Mock(return_value=mock_loss_output) + + # Compute loss + result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1) + + # Check that gradients flow + result.mean().backward() + + # Verify gradients flowed through + assert 
loss.grad is not None + +def test_diffusion_dpo_no_ref_grad(): + batch_size = 4 + channels = 4 + height, width = 64, 64 + + loss = torch.rand(batch_size, channels, height, width, requires_grad=True) + + # Set up mock that tracks if it was called with no_grad + mock_unet_output = torch.rand(batch_size, channels, height, width) + call_unet = Mock(return_value=mock_unet_output) + + mock_loss_output = torch.rand(batch_size, channels, height, width, requires_grad=True) + apply_loss = Mock(return_value=mock_loss_output) + + # Run function + result, _ = diffusion_dpo_loss(loss, call_unet, apply_loss, 0.1) + result.mean().backward() + + # Check that the reference loss has no gradients (was computed with torch.no_grad()) + # This is a bit tricky to test directly, but we can verify call_unet was called + call_unet.assert_called_once() + apply_loss.assert_called_once() + + # The mock_loss_output should not receive gradients as it's used inside torch.no_grad() + assert mock_loss_output.grad is None diff --git a/tests/library/test_custom_train_functions_mapo.py b/tests/library/test_custom_train_functions_mapo.py new file mode 100644 index 00000000..88228f72 --- /dev/null +++ b/tests/library/test_custom_train_functions_mapo.py @@ -0,0 +1,116 @@ +import torch +import numpy as np + +from library.custom_train_functions import mapo_loss + + +def test_mapo_loss_basic(): + batch_size = 4 + channels = 4 + height, width = 64, 64 + + # Create dummy loss tensor with shape [B, C, H, W] + loss = torch.rand(batch_size, channels, height, width) + mapo_weight = 0.5 + + result, metrics = mapo_loss(loss, mapo_weight) + + # Check return types + assert isinstance(result, torch.Tensor) + assert isinstance(metrics, dict) + + # Check required metrics are present + expected_keys = ["total_loss", "ratio_loss", "model_losses_w", "model_losses_l", "win_score", "lose_score"] + for key in expected_keys: + assert key in metrics + assert isinstance(metrics[key], float) + + +def 
test_mapo_loss_different_shapes(): + # Test with different tensor shapes + shapes = [ + (2, 4, 32, 32), # Small tensor + (4, 16, 64, 64), # Medium tensor + (6, 32, 128, 128), # Larger tensor + ] + + for shape in shapes: + loss = torch.rand(*shape) + result, metrics = mapo_loss(loss, 0.5) + + # The result should be a scalar tensor + assert result.shape == torch.Size([]) + + # All metrics should be scalars + for val in metrics.values(): + assert np.isscalar(val) + + +def test_mapo_loss_with_zero_weight(): + loss = torch.rand(4, 3, 64, 64) + result, metrics = mapo_loss(loss, 0.0) + + # With zero mapo_weight, ratio_loss should be zero + assert metrics["ratio_loss"] == 0.0 + + # result should be equal to mean of model_losses_w + assert torch.isclose(result, torch.tensor(metrics["model_losses_w"])) + + +def test_mapo_loss_with_different_timesteps(): + loss = torch.rand(4, 4, 32, 32) + + # Test with different timestep values + timesteps = [1, 10, 100, 1000] + + for ts in timesteps: + result, metrics = mapo_loss(loss, 0.5, ts) + + # Check that the results are different for different timesteps + if ts > 1: + result_prev, metrics_prev = mapo_loss(loss, 0.5, ts // 10) + # Log odds should be affected by timesteps, so ratio_loss should change + assert metrics["ratio_loss"] != metrics_prev["ratio_loss"] + + +def test_mapo_loss_win_loss_scores(): + # Create a controlled input where win losses are lower than lose losses + batch_size = 4 + channels = 4 + height, width = 64, 64 + + # Create losses where winning examples have lower loss + w_loss = torch.ones(batch_size // 2, channels, height, width) * 0.1 + l_loss = torch.ones(batch_size // 2, channels, height, width) * 0.9 + + # Concatenate to create the full loss tensor + loss = torch.cat([w_loss, l_loss], dim=0) + + # Run the function + result, metrics = mapo_loss(loss, 0.5) + + # Win score should be higher than lose score (better performance) + assert metrics["win_score"] > metrics["lose_score"] + + # Model losses for winners 
should be lower + assert metrics["model_losses_w"] < metrics["model_losses_l"] + + +def test_mapo_loss_gradient_flow(): + # Test that gradients flow through the loss function + batch_size = 4 + channels = 4 + height, width = 64, 64 + + # Create a loss tensor that requires grad + loss = torch.rand(batch_size, channels, height, width, requires_grad=True) + mapo_weight = 0.5 + + # Compute loss + result, _ = mapo_loss(loss, mapo_weight) + + # Check that gradients flow + result.backward() + + # If gradients flow, loss.grad should not be None + assert loss.grad is not None diff --git a/train_network.py b/train_network.py index 040d5204..7711f4ca 100644 --- a/train_network.py +++ b/train_network.py @@ -43,6 +43,8 @@ from library.custom_train_functions import ( add_v_prediction_like_loss, apply_debiased_estimation, apply_masked_loss, + diffusion_dpo_loss, + mapo_loss ) from library.utils import setup_logging, add_logging_arguments @@ -255,18 +257,18 @@ class NetworkTrainer: def get_noise_pred_and_target( self, - args, - accelerator, + args: argparse.Namespace, + accelerator: Accelerator, noise_scheduler, - latents, - batch, + latents: torch.FloatTensor, + batch: dict[str, torch.Tensor], text_encoder_conds, unet, network, - weight_dtype, - train_unet, + weight_dtype: torch.dtype, + train_unet: bool, is_train=True, - ): + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) @@ -321,9 +323,9 @@ class NetworkTrainer: network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - return noise_pred, target, timesteps, None + return noise_pred, noisy_latents, target, timesteps, 
None - def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: + def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: @@ -380,10 +382,12 @@ class NetworkTrainer: is_train=True, train_text_encoder=True, train_unet=True, - ) -> torch.Tensor: + multipliers=1.0, + ) -> tuple[torch.Tensor, dict[str, float | int]]: """ Process a batch for the network """ + metrics: dict[str, float | int] = {} with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) @@ -446,7 +450,7 @@ class NetworkTrainer: text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target( + noise_pred, noisy_latents, target, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -468,85 +472,34 @@ class NetworkTrainer: loss = apply_masked_loss(loss, batch) if args.beta_dpo is not None: - model_loss = loss.mean(dim=list(range(1, len(loss.shape)))) - model_loss_w, model_loss_l = model_loss.chunk(2) - raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean()) - model_diff = model_loss_w - model_loss_l - - # ref loss - with torch.no_grad(): - # disable network for reference + def call_unet(): accelerator.unwrap_model(network).set_multiplier(0.0) + ref_noise_pred = self.call_unet( + args, + accelerator, + unet, + noisy_latents.requires_grad_(train_unet), + timesteps, + text_encoder_conds, + batch, + weight_dtype, + ) - with accelerator.autocast(): - ref_noise_pred = self.call_unet( - args, - accelerator, - unet, - noisy_latents.requires_grad_(train_unet), - timesteps, - 
text_encoder_conds, - batch, - weight_dtype, - ) + # reset network multipliers + accelerator.unwrap_model(network).set_multiplier(1.0) + return ref_noise_pred + def apply_loss(ref_noise_pred): + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) ref_loss = train_util.conditional_loss( ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): ref_loss = apply_masked_loss(ref_loss, batch) + return ref_loss - ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape)))) - ref_losses_w, ref_losses_l = ref_loss.chunk(2) - ref_diff = ref_losses_w - ref_losses_l - raw_ref_loss = ref_loss.mean() - - # reset network multipliers - accelerator.unwrap_model(network).set_multiplier(multipliers) - - scale_term = -0.5 * args.beta_dpo - inside_term = scale_term * (model_diff - ref_diff) - loss = -1 * torch.nn.functional.logsigmoid(inside_term) - - implicit_acc = (inside_term > 0).sum().float() / inside_term.size(0) - implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) - - accelerator.log({ - "total_loss": model_loss.detach().mean().item(), - "raw_model_loss": raw_model_loss.detach().mean().item(), - "ref_loss": raw_ref_loss.detach().item(), - "implicit_acc": implicit_acc.detach().item(), - }, step=global_step) + loss, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, args.beta_dpo) elif args.mapo_weight is not None: - model_loss = loss.mean(dim=list(range(1, len(loss.shape)))) - - snr = 0.5 - model_losses_w, model_losses_l = model_loss.chunk(2) - log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - ( - snr * model_losses_l - ) / (torch.exp(snr * model_losses_l) - 1) - - # Ratio loss. - # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process. 
- ratio = torch.nn.functional.logsigmoid(log_odds * noise_scheduler.config.num_train_timesteps) - ratio_losses = args.mapo_weight * ratio - - # Full MaPO loss - loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape)))) - - accelerator.log({ - "total_loss": loss.detach().mean().item(), - "ratio_loss": -ratio_losses.mean().detach().item(), - "model_losses_w": model_losses_w.mean().detach().item(), - "model_losses_l": model_losses_l.mean().detach().item(), - "win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)) - .mean() - .detach() - .item(), - "lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)) - .mean() - .detach() - .item() - }, step=global_step) + loss, metrics = mapo_loss(loss, args.mapo_weight, noise_scheduler.config.num_train_timesteps) else: loss = loss.mean([1, 2, 3]) @@ -555,7 +508,7 @@ class NetworkTrainer: loss = self.post_process_loss(loss, args, timesteps, noise_scheduler) - return loss.mean() + return loss.mean(), metrics def train(self, args): session_id = random.randint(0, 2**32) @@ -1482,7 +1435,7 @@ class NetworkTrainer: # preprocess batch for each model self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True) - loss = self.process_batch( + loss, batch_metrics = self.process_batch( batch, text_encoders, unet, @@ -1584,7 +1537,7 @@ class NetworkTrainer: mean_grad_norm, mean_combined_norm, ) - self.step_logging(accelerator, logs, global_step, epoch + 1) + self.step_logging(accelerator, {**logs, **batch_metrics}, global_step, epoch + 1) # VALIDATION PER STEP: global_step is already incremented # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...