diff --git a/train_network.py b/train_network.py index 6cebf5fc..2f8797d2 100644 --- a/train_network.py +++ b/train_network.py @@ -470,7 +470,7 @@ class NetworkTrainer: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) + loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights