diff --git a/flux_train.py b/flux_train.py index f6e43b27..cfe14885 100644 --- a/flux_train.py +++ b/flux_train.py @@ -667,7 +667,7 @@ def train(args): # calculate loss loss = train_util.conditional_loss( - args, noise_pred.float(), target.float(), timesteps, "none", noise_scheduler + args, model_pred.float(), target.float(), timesteps, "none", noise_scheduler ) if weighting is not None: loss = loss * weighting