simplify and update alpha mask to work with various cases

This commit is contained in:
Kohya S
2024-05-19 21:26:18 +09:00
parent f2dd43e198
commit da6fea3d97
10 changed files with 140 additions and 105 deletions

View File

@@ -474,10 +474,8 @@ def train(args):
target = noise
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
if args.masked_loss:
loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
if "alpha_mask" in batch and batch["alpha_mask"] is not None:
loss = apply_masked_loss(loss, batch["alpha_mask"])
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight