This commit is contained in:
Dave Lage
2026-03-04 06:45:29 +01:00
committed by GitHub
8 changed files with 782 additions and 60 deletions

View File

@@ -327,7 +327,7 @@ class NetworkTrainer:
return noise_pred, target, timesteps, None
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler, latents: Optional[torch.Tensor]) -> torch.FloatTensor:
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
if args.scale_v_pred_loss_like_noise_pred:
@@ -464,7 +464,7 @@ class NetworkTrainer:
is_train=is_train,
)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
if weighting is not None:
loss = loss * weighting
@@ -475,7 +475,7 @@ class NetworkTrainer:
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler, latents)
return loss.mean()