Add an option to use the image's alpha channel as a loss mask (#1223)

* Add alpha_mask parameter and apply masked loss

* Fix type hint in trim_and_resize_if_required function

* Refactor code to use keyword arguments in train_util.py

* Fix alpha mask flipping logic

* Fix alpha mask initialization

* Fix alpha_mask transformation

* Cache alpha_mask

* Update alpha_masks to be on CPU

* Set flipped_alpha_masks to None if option disabled

* Check if alpha_mask is None

* Set alpha_mask to None if option disabled

* Add description of alpha_mask option to docs
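
The masked-loss idea summarized in the list above can be illustrated with a minimal sketch. It is not the code from this commit: `alpha_to_mask` and `apply_alpha_masked_loss` are hypothetical names, and the sketch assumes the images arrive as an `(N, 4, H, W)` uint8 RGBA batch whose alpha channel is scaled to the range 0 to 1 before use.

```python
import torch
import torch.nn.functional as F


def alpha_to_mask(rgba: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: extract alpha from an (N, 4, H, W) uint8 batch as a 0..1 mask."""
    # Slice with 3:4 so the channel dimension is kept: result is (N, 1, H, W).
    return rgba[:, 3:4].float() / 255.0


def apply_alpha_masked_loss(loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Weight a per-element loss of shape (N, C, h, w) by a mask of shape (N, 1, H, W)."""
    mask = mask.to(dtype=loss.dtype)
    # Resize the mask to the spatial size of the loss; "area" averages source pixels.
    mask = F.interpolate(mask, size=loss.shape[2:], mode="area")
    # The single mask channel broadcasts across all loss channels.
    return loss * mask


# Example: an 8x8 image whose left half is opaque and right half is transparent.
rgba = torch.zeros(1, 4, 8, 8, dtype=torch.uint8)
rgba[:, 3, :, :4] = 255
loss = torch.ones(1, 4, 4, 4)
masked = apply_alpha_masked_loss(loss, alpha_to_mask(rgba))
```

In this example the transparent half of the image contributes nothing to the loss, while the opaque half is left unchanged. Partially transparent pixels would scale the loss proportionally.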
u-haru
2024-05-19 19:07:25 +09:00
committed by GitHub
parent febc5c59fa
commit db6752901f
10 changed files with 105 additions and 129 deletions


@@ -479,9 +479,10 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
     return noise
-def apply_masked_loss(loss, batch):
+def apply_masked_loss(loss, mask_image):
     # mask image is -1 to 1. we need to convert it to 0 to 1
-    mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+    # mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+    mask_image = mask_image.to(dtype=loss.dtype)
     # resize to the same size as the loss
     mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")