Fix sigmas/timesteps

This commit is contained in:
rockerBOO
2025-06-03 20:55:29 -04:00
parent 415233993a
commit db05136480
4 changed files with 5 additions and 5 deletions

View File

@@ -966,8 +966,6 @@ def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[T
log_ratio_l = -loss_l + ref_loss_l
psi_l = beta * log_ratio_l # [batch_size]
print((w_theta_max * psi_w - w_theta_max * psi_l).mean())
# Final SDPO loss computation
logits = w_theta_max * psi_w - w_theta_max * psi_l # [batch_size]
sigmoid_loss = -torch.log(torch.sigmoid(logits)) # [batch_size]

View File

@@ -502,9 +502,9 @@ def get_noisy_model_input_and_timestep(
sigma = torch.randn(bsz, device=device)
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
sigma = sigma.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigma = time_shift(mu, 1.0, sigma)
timestep = sigma * num_timesteps
timestep = noise_scheduler._sigma_to_t(sigma)
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly

View File

@@ -2297,7 +2297,6 @@ class DreamBoothDataset(BaseDataset):
rejected_image_info.resize_interpolation = resize_interpolation
info = ImageSetInfo([chosen_image_info, rejected_image_info])
print(chosen_image_info.image_size, rejected_image_info.image_size)
else:
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = (

View File

@@ -13,6 +13,9 @@ class MockNoiseScheduler:
self.config.num_train_timesteps = num_train_timesteps
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
# Test-double sigma-to-timestep conversion: returns sigma multiplied by the mock's configured num_train_timesteps.
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
# Create fixtures for commonly used objects
@pytest.fixture