Implement pseudo Huber loss for Flux and SD3

This commit is contained in:
recris
2024-11-27 18:11:51 +00:00
parent 2a61fc0784
commit 420a180d93
15 changed files with 76 additions and 61 deletions

View File

@@ -378,7 +378,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, None, weighting
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss