diff --git a/flux_train_network.py b/flux_train_network.py index 031d9b69..d7bff288 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -347,16 +347,23 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): weight_dtype: torch.dtype, train_unet: bool, is_train=True, - ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: + timesteps: torch.FloatTensor | None = None, + ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # get noisy model input and timesteps - noisy_model_input, timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep( + noisy_model_input, rand_timesteps, sigmas = flux_train_utils.get_noisy_model_input_and_timestep( args, noise_scheduler, latents, noise, accelerator.device, weight_dtype ) + if timesteps is None: + timesteps = rand_timesteps + else: + # Convert the provided timesteps into sigmas (sigma = t / T; subtraction gave negative sigmas). NOTE(review): noisy_model_input above is still built from rand_timesteps — confirm callers only need matching sigmas/timesteps, not re-noised latents + sigmas: torch.FloatTensor = timesteps / noise_scheduler.config.num_train_timesteps + # pack latents and get img_ids packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4 packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2 diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 43893de0..7194b5c3 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -568,85 +568,36 @@ def mapo_loss(loss: torch.Tensor, mapo_weight: float, num_train_timesteps=1000) loss = loss_w.mean(dim=1) - ratio_losses.mean(dim=1) metrics = { - "total_loss": loss.detach().mean().item(), - "ratio_loss": -ratio_losses.detach().mean().item(), - "model_losses_w": loss_w.detach().mean().item(), - "model_losses_l": loss_l.detach().mean().item(), - "win_score": ((snr * loss_w) / 
(torch.exp(snr * loss_w) - 1)).detach().mean().item(), - "lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(), + "loss/diffusion_dpo_total": loss.detach().mean().item(), + "loss/diffusion_dpo_ratio": -ratio_losses.detach().mean().item(), + "loss/diffusion_dpo_w_loss": loss_w.detach().mean().item(), + "loss/diffusion_dpo_l_loss": loss_l.detach().mean().item(), + "loss/diffusion_dpo_win_score": ((snr * loss_w) / (torch.exp(snr * loss_w) - 1)).detach().mean().item(), + "loss/diffusion_dpo_lose_score": ((snr * loss_l) / (torch.exp(snr * loss_l) - 1)).detach().mean().item(), } return loss, metrics - -def ddo_loss(loss: Tensor, ref_loss: Tensor, ddo_alpha: float = 4.0, ddo_beta: float = 0.05, weighting: Tensor | None = None): - """ - Calculate DDO loss for flow matching diffusion models. - - This implementation follows the paper's approach: - 1. Use prediction errors as proxy for log likelihood ratio - 2. Apply sigmoid to create a discriminator from this ratio - 3. 
Optimize using the standard GAN discriminator loss - - Args: - loss: loss B, N - ref_loss: ref loss B, N - ddo_alpha: Weight for the fake sample term - ddo_beta: Scaling factor for the likelihood ratio - weighting: Optional time-dependent weighting - - Returns: - The DDO loss value - """ - # Calculate per-sample MSE between predictions and target - # Flatten spatial and channel dimensions, keeping batch dimension - # target_error = ((noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1) - # ref_error = ((ref_noise_pred - target)**2).reshape(batch_size, -1).mean(dim=1) - - # Apply weighting if provided (e.g., for time-dependent importance) - if weighting is not None: - if isinstance(weighting, tuple): - # Use first element if it's a tuple - weighting = weighting[0] - if weighting.ndim > 1: - # Ensure weighting is the right shape - weighting = weighting.view(-1) - loss = loss * weighting - ref_loss = ref_loss * weighting - - # Calculate the log likelihood ratio - # For flow matching, lower error = higher likelihood - # So the log ratio is proportional to negative of error difference +def ddo_loss(loss, ref_loss, ddo_alpha: float = 4.0, ddo_beta: float = 0.05): + ref_loss = ref_loss.detach() # Ensure no gradients to reference log_ratio = ddo_beta * (ref_loss - loss) - - # Divide batch into real and fake samples (mid-point split) - # In this implementation, the entire batch is treated as real samples - # and each sample is compared against its own reference prediction - # This approach works because the reference model (with LoRA disabled) - # produces predictions that serve as the "fake" distribution - - # Loss for real samples: maximize log σ(ratio) - real_loss_terms = -torch.nn.functional.logsigmoid(log_ratio) - real_loss = real_loss_terms.mean() - - # Loss for fake samples: maximize log(1-σ(ratio)) - # Since we're using the same batch for both real and fake, - # we interpret this as maximizing log(1-σ(ratio)) for the samples when viewed from reference - 
fake_loss_terms = -torch.nn.functional.logsigmoid(-log_ratio) - fake_loss = ddo_alpha * fake_loss_terms.mean() - + real_loss = -torch.nn.functional.logsigmoid(log_ratio).mean() # stable -log(sigmoid(x)); no 1e-6 fudge needed + fake_loss = -ddo_alpha * torch.nn.functional.logsigmoid(-log_ratio).mean() # log(1 - sigmoid(x)) == logsigmoid(-x), avoids log(0) total_loss = real_loss + fake_loss metrics = { "loss/ddo_real": real_loss.detach().item(), "loss/ddo_fake": fake_loss.detach().item(), "loss/ddo_total": total_loss.detach().item(), - "ddo_log_ratio_mean": log_ratio.detach().mean().item(), + "loss/ddo_sigmoid_log_ratio": torch.sigmoid(log_ratio).mean().item(), } + # logger.debug(f"loss mean: {loss.mean().item()}, ref_loss mean: {ref_loss.mean().item()}") + # logger.debug(f"difference: {(ref_loss - loss).mean().item()}") + # logger.debug(f"log_ratio range: {log_ratio.min().item()} to {log_ratio.max().item()}") + # logger.debug(f"sigmoid(log_ratio) mean: {torch.sigmoid(log_ratio).mean().item()}") return total_loss, metrics - """ ########################################## # Perlin Noise diff --git a/train_network.py b/train_network.py index 1748df92..6afc50c3 100644 --- a/train_network.py +++ b/train_network.py @@ -270,10 +270,14 @@ class NetworkTrainer: weight_dtype: torch.dtype, train_unet: bool, is_train=True, + timesteps=None, ) -> tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.IntTensor, torch.Tensor | None]: # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, rand_timesteps = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + + if timesteps is None: + timesteps = rand_timesteps # ensure the hidden state will require grad if args.gradient_checkpointing: @@ -475,34 +479,34 @@ class NetworkTrainer: loss = apply_masked_loss(loss, batch) if args.ddo_beta is not None 
or args.ddo_alpha is not None: - with torch.no_grad(): - accelerator.unwrap_model(network).set_multiplier(0.0) - ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, _weighting = self.get_noise_pred_and_target( - args, - accelerator, - noise_scheduler, - latents, - batch, - text_encoder_conds, - unet, - network, - weight_dtype, - train_unet, - is_train=False, - ) + accelerator.unwrap_model(network).set_multiplier(0.0) + ref_noise_pred, ref_noisy_latents, ref_target, ref_sigmas, ref_timesteps, ref_weighting = self.get_noise_pred_and_target( + args, + accelerator, + noise_scheduler, + latents, + batch, + text_encoder_conds, + unet, + network, + weight_dtype, + train_unet, + is_train=False, + timesteps=timesteps, + ) - # reset network multipliers - accelerator.unwrap_model(network).set_multiplier(1.0) + # reset network multipliers + accelerator.unwrap_model(network).set_multiplier(1.0) - # Apply DDO loss - huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) + huber_c = train_util.get_huber_threshold_if_needed(args, ref_timesteps, noise_scheduler) ref_loss= train_util.conditional_loss(ref_noise_pred.float(), ref_target.float(), args.loss_type, "none", huber_c) + # NOTE(review): removed dead code — `ddo_weighting = weighting * ref_weighting` was + # assigned but never used; per-branch weighting is applied inline in the call below loss, metrics_ddo = ddo_loss( - loss.mean(dim=(1, 2, 3)), - ref_loss.mean(dim=(1, 2, 3)), + loss.mean(dim=(1, 2, 3)) * (weighting if weighting is not None else 1), + ref_loss.mean(dim=(1, 2, 3)) * (ref_weighting if ref_weighting is not None else 1), args.ddo_alpha or 4.0, args.ddo_beta or 0.05, - weighting ) metrics = {**metrics, **metrics_ddo} elif args.beta_dpo is not None: