mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Merge 4c8ebf7293 into d633b51126
This commit is contained in:
@@ -327,7 +327,7 @@ class NetworkTrainer:
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler, latents: Optional[torch.Tensor]) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
@@ -464,7 +464,7 @@ class NetworkTrainer:
|
||||
is_train=is_train,
|
||||
)
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if weighting is not None:
|
||||
loss = loss * weighting
|
||||
@@ -475,7 +475,7 @@ class NetworkTrainer:
|
||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler)
|
||||
loss = self.post_process_loss(loss, args, timesteps, noise_scheduler, latents)
|
||||
|
||||
return loss.mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user