From db05136480069da65e80b55f564f9c38487b29e2 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 3 Jun 2025 20:55:29 -0400 Subject: [PATCH] Fix sigmas/timesteps --- library/custom_train_functions.py | 2 -- library/flux_train_utils.py | 4 ++-- library/train_util.py | 1 - tests/library/test_flux_train_utils.py | 3 +++ 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 0d3da3a1..bab63f0c 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -966,8 +966,6 @@ def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[T log_ratio_l = -loss_l + ref_loss_l psi_l = beta * log_ratio_l # [batch_size] - print((w_theta_max * psi_w - w_theta_max * psi_l).mean()) - # Final SDPO loss computation logits = w_theta_max * psi_w - w_theta_max * psi_l # [batch_size] sigmoid_loss = -torch.log(torch.sigmoid(logits)) # [batch_size] diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 626038c0..3e0d7d95 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -502,9 +502,9 @@ def get_noisy_model_input_and_timestep( sigma = torch.randn(bsz, device=device) sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling sigma = sigma.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size sigma = time_shift(mu, 1.0, sigma) - timestep = sigma * num_timesteps + timestep = noise_scheduler._sigma_to_t(sigma) else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly diff --git a/library/train_util.py b/library/train_util.py index 10972f63..240fa039 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2297,7 +2297,6 @@ class DreamBoothDataset(BaseDataset): rejected_image_info.resize_interpolation = resize_interpolation info = ImageSetInfo([chosen_image_info, rejected_image_info]) - print(chosen_image_info.image_size, rejected_image_info.image_size) else: info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path) info.resize_interpolation = ( diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py index 66e22e5c..37229396 100644 --- a/tests/library/test_flux_train_utils.py +++ b/tests/library/test_flux_train_utils.py @@ -13,6 +13,9 @@ class MockNoiseScheduler: self.config.num_train_timesteps = num_train_timesteps self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + # Create fixtures for commonly used objects @pytest.fixture