Add custom train functions test for loss modifications

This commit is contained in:
rockerBOO
2025-03-20 16:58:49 -04:00
parent 8d5a183cc5
commit 3ffd3b84a5
2 changed files with 252 additions and 22 deletions

View File

@@ -87,7 +87,7 @@ def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
def apply_snr_weight(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, gamma: Number, v_prediction=False, image_size=None):
# Get the appropriate SNR values based on timesteps and potentially image size
if hasattr(noise_scheduler, "get_snr_for_timestep"):
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
snr = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
else:
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
@@ -109,7 +109,7 @@ def scale_v_prediction_loss_like_noise_prediction(loss: torch.Tensor, timesteps:
def get_snr_scale(timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, image_size=None):
# Get SNR values with image_size consideration
if hasattr(noise_scheduler, "get_snr_for_timestep"):
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
else:
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
@@ -131,27 +131,30 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None):
# Check if we have SNR values available
if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")):
return loss
if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")):
return loss
# Get SNR values with image_size consideration
if hasattr(noise_scheduler, "get_snr_for_timestep"):
snr_t = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
else:
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
# Cap the SNR to avoid numerical issues
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
# Apply weighting based on prediction type
if v_prediction:
weight = 1 / (snr_t + 1)
else:
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss
if not callable(noise_scheduler.get_snr_for_timestep):
return loss
# Get SNR values with image_size consideration
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
else:
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
# Cap the SNR to avoid numerical issues
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
# Apply weighting based on prediction type
if v_prediction:
weight = 1 / (snr_t + 1)
else:
weight = 1 / torch.sqrt(snr_t)
loss = weight * loss
return loss