simplify and update alpha mask to work with various cases

This commit is contained in:
Kohya S
2024-05-19 21:26:18 +09:00
parent f2dd43e198
commit da6fea3d97
10 changed files with 140 additions and 105 deletions

View File

@@ -479,14 +479,19 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise
def apply_masked_loss(loss, mask_image):
# mask image is -1 to 1. we need to convert it to 0 to 1
# mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image.to(dtype=loss.dtype)
def apply_masked_loss(loss, batch):
if "conditioning_images" in batch:
# conditioning image is -1 to 1. we need to convert it to 0 to 1
mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel
mask_image = mask_image / 2 + 0.5
elif "alpha_masks" in batch and batch["alpha_masks"] is not None:
# alpha mask is 0 to 1
mask_image = batch["alpha_masks"].to(dtype=loss.dtype)
else:
return loss
# resize to the same size as the loss
mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
mask_image = mask_image / 2 + 0.5
loss = loss * mask_image
return loss