From f320c1b964bb07a9ee8d5a881f1198e49fd20506 Mon Sep 17 00:00:00 2001 From: kohya-ss <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 9 Feb 2026 12:44:47 +0900 Subject: [PATCH] feat: update loss calculation to support 5d tensor --- train_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 6cebf5fc..2f8797d2 100644 --- a/train_network.py +++ b/train_network.py @@ -470,7 +470,7 @@ class NetworkTrainer: loss = loss * weighting if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) + loss = loss.mean(dim=list(range(1, loss.ndim))) # mean over all dims except batch loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights