simplify and update alpha mask to work with various cases

This commit is contained in:
Kohya S
2024-05-19 21:26:18 +09:00
parent f2dd43e198
commit da6fea3d97
10 changed files with 140 additions and 105 deletions

View File

@@ -711,10 +711,8 @@ def train(args):
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
loss = apply_masked_loss(loss, batch["alpha_mask"])
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
if args.min_snr_gamma: