From 9a2101a0401ea9f735ef4c683e4bee21b2270255 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 30 Apr 2025 03:34:19 -0400 Subject: [PATCH] Add DDO loss --- flux_train_network.py | 1 + library/custom_offloading_utils.py | 23 +-- library/custom_train_functions.py | 265 ++++++++++------------------- train_network.py | 119 ++++++------- 4 files changed, 160 insertions(+), 248 deletions(-) diff --git a/flux_train_network.py b/flux_train_network.py index c619afac..031d9b69 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -384,6 +384,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode with torch.set_grad_enabled(is_train), accelerator.autocast(): # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) + accelerator.unwrap_model(unet).prepare_block_swap_before_forward() model_pred = unet( img=img, img_ids=img_ids, diff --git a/library/custom_offloading_utils.py b/library/custom_offloading_utils.py index 84c2b743..6411e0b7 100644 --- a/library/custom_offloading_utils.py +++ b/library/custom_offloading_utils.py @@ -42,19 +42,20 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye torch.cuda.current_stream().synchronize() # this prevents the illegal loss value - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - # cuda to cpu - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: - cuda_data_view.record_stream(stream) - module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) + with torch.no_grad(): + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # cuda to cpu + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.record_stream(stream) + module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True) - stream.synchronize() + stream.synchronize() - # cpu to cuda - for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: - cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) - module_to_cuda.weight.data = cuda_data_view + # cpu to cuda + for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs: + cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True) + module_to_cuda.weight.data = cuda_data_view stream.synchronize() torch.cuda.current_stream().synchronize() # this prevents the illegal loss value diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 21af1184..0689a159 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -505,41 +505,23 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: loss = loss * mask_image return loss -def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor], apply_loss: Callable[[torch.Tensor], torch.Tensor], beta_dpo: float): +def diffusion_dpo_loss(loss: torch.Tensor, ref_loss: Tensor, beta_dpo: float): """ - DPO loss + Diffusion DPO loss Args: - loss: pairs of w, l losses B//2, C, H, W - call_unet: function to call unet - apply_loss: function to apply loss + loss: pairs of w, l losses B//2 + ref_loss: ref pairs of w, l losses B//2 beta_dpo: beta_dpo weight - - Returns: - tuple: - - loss: mean loss of C, H, W - - metrics: - - total_loss: mean loss of C, H, W - - raw_model_loss: mean loss of C, H, W - - ref_loss: mean loss of C, H, W - - implicit_acc: accumulated implicit of C, H, W - """ - model_loss = loss.mean(dim=list(range(1, len(loss.shape)))) - model_loss_w, model_loss_l = model_loss.chunk(2) - raw_model_loss = 0.5 * (model_loss_w.mean() + model_loss_l.mean()) - model_diff = model_loss_w - model_loss_l - - # ref loss - with torch.no_grad(): - ref_noise_pred = call_unet() - ref_loss = apply_loss(ref_noise_pred) - ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape)))) - ref_losses_w, ref_losses_l = ref_loss.chunk(2) - ref_diff = ref_losses_w - ref_losses_l - raw_ref_loss = ref_loss.mean() + loss_w, loss_l = loss.chunk(2) + raw_loss = 0.5 * (loss_w.mean(dim=1) + loss_l.mean(dim=1)) + model_diff = loss_w - loss_l + ref_losses_w, ref_losses_l = ref_loss.chunk(2) + ref_diff = ref_losses_w - ref_losses_l + raw_ref_loss = ref_loss.mean(dim=1) scale_term = -0.5 * beta_dpo inside_term = scale_term * (model_diff - ref_diff) @@ -549,10 +531,10 @@ def diffusion_dpo_loss(loss: torch.Tensor, call_unet: Callable[[],torch.Tensor], implicit_acc += 0.5 * (inside_term == 0).sum().float() / inside_term.size(0) metrics = { - "total_loss": model_loss.detach().mean().item(), - "raw_model_loss": raw_model_loss.detach().mean().item(), - "ref_loss": raw_ref_loss.detach().item(), - "implicit_acc": implicit_acc.detach().item(), + "loss/diffusion_dpo_total_loss": loss.detach().mean().item(), + "loss/diffusion_dpo_raw_loss": raw_loss.detach().mean().item(), + "loss/diffusion_dpo_ref_loss": raw_ref_loss.detach().item(), + "loss/diffusion_dpo_implicit_acc": implicit_acc.detach().item(), } return loss, metrics @@ -563,28 +545,15 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) Args: loss: pairs of w, l losses B//2, C, H, W - mapo_loss: mapo weight + mapo_weight: mapo weight num_train_timesteps: number of timesteps - - Returns: - tuple: - - loss: mean loss of C, H, W - - metrics: - - total_loss: mean loss of C, H, W - - ratio_loss: mean ratio loss of C, H, W - - model_losses_w: mean loss of w losses of C, H, W - - model_losses_l: mean loss of l losses of C, H, W - - win_score : mean win score of C, H, W - - lose_score : mean lose score of C, H, W - """ - model_loss = loss.mean(dim=list(range(1, len(loss.shape)))) snr = 0.5 - model_losses_w, model_losses_l = model_loss.chunk(2) - log_odds = (snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1) - ( - snr * model_losses_l - ) / (torch.exp(snr * model_losses_l) - 1) + loss_w, loss_l = loss.chunk(2) + log_odds = (snr * loss_w) / (torch.exp(snr * loss_w) - 1) - ( + snr * loss_l + ) / (torch.exp(snr * loss_l) - 1) # Ratio loss. # By multiplying T to the inner term, we try to maximize the margin throughout the overall denoising process. @@ -592,141 +561,91 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) ratio_losses = mapo_weight * ratio # Full MaPO loss - loss = model_losses_w.mean(dim=list(range(1, len(model_losses_w.shape)))) - ratio_losses.mean(dim=list(range(1, len(ratio_losses.shape)))) + loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1) metrics = { "total_loss": loss.detach().mean().item(), "ratio_loss": -ratio_losses.detach().mean().item(), - "model_losses_w": model_losses_w.detach().mean().item(), - "model_losses_l": model_losses_l.detach().mean().item(), - "win_score": ((snr * model_losses_w) / (torch.exp(snr * model_losses_w) - 1)).detach().mean().item(), - "lose_score": ((snr * model_losses_l) / (torch.exp(snr * model_losses_l) - 1)).detach().mean().item(), + "model_losses_w": loss_w.detach().mean().item(), + "model_losses_l": loss_l.detach().mean().item(), + "win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(), + "lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(), } return loss, metrics -class FlowMatchingDDOLoss(nn.Module): - def __init__(self, alpha=4.0, beta=0.05): - super().__init__() - self.alpha = alpha - self.beta = beta - - def forward( - self, v_theta: Tensor, v_theta_ref: Tensor, v_target: Tensor, time=None - ): - """ - Compute DDO loss for flow matching models - - Args: - v_theta: Vector field predicted by target model - v_theta_ref: Vector field predicted by reference model - v_target: Target vector field (e.g., straight-line for rectified flow) - time: Time parameter t - - Returns: - DDO loss value - """ - # For flow matching, error is based on vector field difference - error_theta = torch.sum((v_theta - v_target) ** 2, dim=[1, 2, 3]) - error_theta_ref = torch.sum((v_theta_ref - v_target) ** 2, dim=[1, 2, 3]) - - # Likelihood ratio approximation - delta = error_theta_ref - error_theta - scaled_delta = self.beta * delta - - # Split batch into real and fake parts - batch_size = v_theta.shape[0] - half_batch = batch_size // 2 - - real_delta = scaled_delta[:half_batch] - fake_delta = scaled_delta[half_batch:] - - real_loss = -F.logsigmoid(real_delta).mean() - fake_loss = -F.logsigmoid(-fake_delta).mean() - - loss = real_loss + self.alpha * fake_loss - - return loss - -def compute_target_velocity(x_t: Tensor, t: Tensor): +def ddo_loss( + loss: Tensor, + ref_loss: Tensor, + ddo_alpha: float=4.0, + ddo_beta: float=0.05, + weighting: Tensor | None=None +): """ - Compute the target velocity vector field for flow matching. - - For rectified flow, the target velocity is the straight-line path derivative. - + Calculate DDO loss for flow matching diffusion models. + + This implementation follows the paper's approach: + 1. Use prediction errors as proxy for log likelihood ratio + 2. Apply sigmoid to create a discriminator from this ratio + 3. Optimize using the standard GAN discriminator loss + Args: - x_t: Points along the path at time t (batch_size, channels, height, width) - t: Time values in [0,1] (batch_size,) - + loss: loss B, N + ref_loss: ref loss B, N + ddo_alpha: Weight for the fake sample term + ddo_beta: Scaling factor for the likelihood ratio + weighting: Optional time-dependent weighting + Returns: - Target velocity vectors v(x_t, t) for flow matching + The DDO loss value """ - batch_size = x_t.shape[0] - - # Get corresponding data and noise endpoints - with torch.no_grad(): - # For each interpolated point, we need the endpoints of its path - # In practice, these might come from a cache or be passed as arguments - x1 = get_data_endpoints(x_t, t) # Real data endpoint (t=0) - x0 = get_noise_endpoints(x_t, t) # Noise endpoint (t=1) - - # Reshape t for broadcasting - t = t.view(batch_size, 1, 1, 1) - - # For standard rectified flow, the target velocity is constant along the path: - # v(x_t, t) = x1 - x0 - v_target = x1 - x0 - - # For time-dependent velocity fields (non-rectified), we would scale by time: - # v_target = v_target * g(t) # where g(t) is a time-dependent scaling function - - return v_target - - -def get_data_endpoints(x_t: Tensor, t: Tensor): - """ - Get the data endpoints (t=0) for the given points on the path. - - For training with real data, this would typically use the encoded real data. - For inference or when using generated endpoints, we'd solve for them. - - Args: - x_t: Points on the path at time t - t: Time values - - Returns: - The data endpoints (x at t=0) - """ - # Solve for x1 using the straight-line path: x_t = (1-t)*x1 + t*x0 - t = t.view(-1, 1, 1, 1) - x0 = torch.randn_like(x_t) # Noise endpoint - - # Solve for x1: x1 = (x_t - t*x0) / (1-t) - # Add small epsilon to prevent division by zero - epsilon = 1e-8 - x1 = (x_t - t * x0) / (torch.clamp(1 - t, min=epsilon)) - - return x1 - - -def get_noise_endpoints(x_t: Tensor, t: Tensor): - """ - Get the noise endpoints (t=1) for the given points on the path. - - For standard rectified flow, this is typically Gaussian noise. - - Args: - x_t: Points on the path at time t - t: Time values - - Returns: - The noise endpoints (x at t=1) - """ - - # Generate noise samples matching the shape of x_t - x0 = torch.randn_like(x_t) - - return x0 + # Calculate per-sample MSE between predictions and target + # Flatten spatial and channel dimensions, keeping batch dimension + # target_error = ((noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1) + # ref_error = ((ref_noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1) + + # Apply weighting if provided (e.g., for time-dependent importance) + if weighting is not None: + if isinstance(weighting, tuple): + # Use first element if it's a tuple + weighting = weighting[0] + if weighting.ndim > 1: + # Ensure weighting is the right shape + weighting = weighting.view(-1) + loss = loss * weighting + ref_loss = ref_loss * weighting + + # Calculate the log likelihood ratio + # For flow matching, lower error = higher likelihood + # So the log ratio is proportional to negative of error difference + log_ratio = ddo_beta * (ref_loss - loss) + + # Divide batch into real and fake samples (mid-point split) + # In this implementation, the entire batch is treated as real samples + # and each sample is compared against its own reference prediction + # This approach works because the reference model (with LoRA disabled) + # produces predictions that serve as the "fake" distribution + + # Loss for real samples: maximize log σ(ratio) + real_loss_terms = -torch.nn.functional.logsigmoid(log_ratio) + real_loss = real_loss_terms.mean() + + # Loss for fake samples: maximize log(1-σ(ratio)) + # Since we're using the same batch for both real and fake, + # we interpret this as maximizing log(1-σ(ratio)) for the samples when viewed from reference + fake_loss_terms = -torch.nn.functional.logsigmoid(-log_ratio) + fake_loss = ddo_alpha * fake_loss_terms.mean() + + total_loss = real_loss + fake_loss + + metrics = { + "loss/ddo_real": real_loss.detach().item(), + "loss/ddo_fake": fake_loss.detach().item(), + "loss/ddo_total": total_loss.detach().item(), + "ddo_log_ratio_mean": log_ratio.detach().mean().item(), + } + + return total_loss, metrics """ diff --git a/train_network.py b/train_network.py index 5731463a..40e24740 100644 --- a/train_network.py +++ b/train_network.py @@ -37,6 +37,7 @@ import library.huggingface_util as huggingface_util import library.custom_train_functions as custom_train_functions from library.custom_train_functions import ( apply_snr_weight, + ddo_loss, get_weighted_text_embeddings, prepare_scheduler_for_custom_training, scale_v_prediction_loss_like_noise_prediction, @@ -45,8 +46,7 @@ from library.custom_train_functions import ( apply_masked_loss, diffusion_dpo_loss, mapo_loss, - FlowMatchingDDOLoss, - compute_target_velocity, + calculate_ddo_loss_for_dit_flow_matching, ) from library.utils import setup_logging, add_logging_arguments @@ -270,7 +270,7 @@ class NetworkTrainer: weight_dtype: torch.dtype, train_unet: bool, is_train=True, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) @@ -324,8 +324,8 @@ class NetworkTrainer: ) network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype) - - return noise_pred, noisy_latents, target, timesteps, None + sigmas = timesteps / noise_scheduler.config.num_train_timesteps + return noise_pred, noisy_latents, target, sigmas, timesteps, None def post_process_loss(self, loss: torch.Tensor, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor: if args.min_snr_gamma: @@ -452,7 +452,8 @@ class NetworkTrainer: text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target - noise_pred, noisy_latents, target, timesteps, weighting = self.get_noise_pred_and_target( + + noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, @@ -466,17 +467,52 @@ class NetworkTrainer: is_train=is_train, ) - if args.ddo_beta is not None or args.ddo_alpha is not None: - # Compute DDO loss - ddo_loss = FlowMatchingDDOLoss(alpha=args.ddo_beta or 4.0, beta=args.ddo_alpha or 0.05) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) + if weighting is not None: + loss = loss * weighting + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) - accelerator.unwrap_model(network).set_multiplier(0.0) - with torch.no_grad(), accelerator.autocast(): - ref_noise_pred, _noisy_latents, ref_target, ref_timesteps, _weighting = self.get_noise_pred_and_target( + if args.ddo_beta is not None or args.ddo_alpha is not None: + with torch.no_grad(): + accelerator.unwrap_model(network).set_multiplier(0.0) + ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target( args, accelerator, noise_scheduler, - torch.rand_like(latents), + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=False, + ) + + # reset network multipliers + accelerator.unwrap_model(network).set_multiplier(1.0) + + # Apply DDO loss + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + ref_loss= train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c) + loss, metrics_ddo = ddo_loss( + loss, + ref_loss, + args.ddo_alpha or 4.0, + args.ddo_beta or 0.05, + weighting + ) + metrics = {**metrics, **metrics_ddo} + elif args.beta_dpo is not None: + with torch.no_grad(): + accelerator.unwrap_model(network).set_multiplier(0.0) + ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, batch, text_encoder_conds, unet, @@ -485,60 +521,15 @@ class NetworkTrainer: train_unet, is_train=is_train, ) - - # reset network multipliers - accelerator.unwrap_model(network).set_multiplier(1.0) - - # Combine real and fake batches - combined_latents = torch.cat([noise_pred, ref_noise_pred], dim=0) - combined_t = torch.cat([timesteps, ref_timesteps], dim=0) - - # Compute target vector field (straight path for rectified flow) - v_target = compute_target_velocity(combined_latents, combined_t) - v_theta = noise_pred - v_theta_ref = ref_noise_pred - - loss = ddo_loss(v_theta, v_theta_ref, v_target, combined_t) - else: - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c) - - if weighting is not None: - loss = loss * weighting - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - loss = apply_masked_loss(loss, batch) - - if args.beta_dpo is not None: - def call_unet(): - accelerator.unwrap_model(network).set_multiplier(0.0) - with torch.no_grad(), accelerator.autocast(): - ref_noise_pred, _noisy_latents, ref_target, ref_timesteps, _weighting = self.get_noise_pred_and_target( - args, - accelerator, - noise_scheduler, - torch.rand_like(latents), - batch, - text_encoder_conds, - unet, - network, - weight_dtype, - train_unet, - is_train=is_train, - ) - # reset network multipliers accelerator.unwrap_model(network).set_multiplier(1.0) - return ref_noise_pred - def apply_loss(ref_noise_pred): - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) - ref_loss = train_util.conditional_loss( - ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c - ) - if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): - ref_loss = apply_masked_loss(ref_loss, batch) - return ref_loss - loss, metrics = diffusion_dpo_loss(loss, call_unet, apply_loss, args.beta_dpo) + huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + ref_loss = train_util.conditional_loss( + ref_noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) + + loss, metrics = diffusion_dpo_loss(loss, ref_loss, args.beta_dpo) elif args.mapo_weight is not None: loss, metrics = mapo_loss(loss, args.mapo_weight, noise_scheduler.config.num_train_timesteps) else: