mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Fix sigmas/timesteps
This commit is contained in:
@@ -966,8 +966,6 @@ def sdpo_loss(loss: Tensor, ref_loss: Tensor, beta=0.02, epsilon=0.1) -> tuple[T
|
||||
log_ratio_l = -loss_l + ref_loss_l
|
||||
psi_l = beta * log_ratio_l # [batch_size]
|
||||
|
||||
print((w_theta_max * psi_w - w_theta_max * psi_l).mean())
|
||||
|
||||
# Final SDPO loss computation
|
||||
logits = w_theta_max * psi_w - w_theta_max * psi_l # [batch_size]
|
||||
sigmoid_loss = -torch.log(torch.sigmoid(logits)) # [batch_size]
|
||||
|
||||
@@ -502,9 +502,9 @@ def get_noisy_model_input_and_timestep(
|
||||
sigma = torch.randn(bsz, device=device)
|
||||
sigma = sigma * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigma = sigma.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigma = time_shift(mu, 1.0, sigma)
|
||||
timestep = sigma * num_timesteps
|
||||
timestep = noise_scheduler._sigma_to_t(sigma)
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
|
||||
@@ -2297,7 +2297,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
rejected_image_info.resize_interpolation = resize_interpolation
|
||||
|
||||
info = ImageSetInfo([chosen_image_info, rejected_image_info])
|
||||
print(chosen_image_info.image_size, rejected_image_info.image_size)
|
||||
else:
|
||||
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
|
||||
info.resize_interpolation = (
|
||||
|
||||
@@ -13,6 +13,9 @@ class MockNoiseScheduler:
|
||||
self.config.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
|
||||
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
|
||||
# Create fixtures for commonly used objects
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user