Add an option to use the image's alpha channel as a mask for the loss (#1223)

* Add alpha_mask parameter and apply masked loss

* Fix type hint in trim_and_resize_if_required function

* Refactor code to use keyword arguments in train_util.py

* Fix alpha mask flipping logic

* Fix alpha mask initialization

* Fix alpha_mask transformation

* Cache alpha_mask

* Update alpha_masks to be on CPU

* Set flipped_alpha_masks to Null if option disabled

* Check if alpha_mask is None

* Set alpha_mask to None if option disabled

* Add description of alpha_mask option to docs
Author: u-haru
Date: 2024-05-19 19:07:25 +09:00
Committed by: GitHub
Parent: febc5c59fa
Commit: db6752901f
10 changed files with 105 additions and 129 deletions


@@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
 * Specify this to use a learning rate for the Text Encoder-related LoRA modules that differs from the one given by the regular `--learning_rate` option. It is sometimes said that a slightly lower learning rate for the Text Encoder, such as 5e-5, works better.
 * `--network_args`
   * Multiple arguments can be specified. Described below.
+* `--alpha_mask`
+  * Use the image's alpha values as a mask. Specify this when training on transparent images. [PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
 When neither `--network_train_unet_only` nor `--network_train_text_encoder_only` is specified, the default is to enable the LoRA modules of both the Text Encoder and U-Net.


@@ -101,6 +101,8 @@ The LoRA model will be saved in the folder specified by the `--output_dir` option
 * Specify this option when using a learning rate for the Text Encoder-related LoRA modules that differs from the regular `--learning_rate` option. It may be best to set the Text Encoder's learning rate slightly lower, e.g. 5e-5.
 * `--network_args`
   * Multiple arguments can be specified. Described in detail below.
+* `--alpha_mask`
+  * Use the image's alpha values as a mask. This is used when training on transparent images. [PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223)
 When neither `--network_train_unet_only` nor `--network_train_text_encoder_only` is specified, by default the LoRA modules of both the Text Encoder and U-Net are enabled.


@@ -78,6 +78,7 @@ class BaseSubsetParams:
 caption_tag_dropout_rate: float = 0.0
 token_warmup_min: int = 1
 token_warmup_step: float = 0
+alpha_mask: bool = False

 @dataclass
@@ -538,6 +539,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
 random_crop: {subset.random_crop}
 token_warmup_min: {subset.token_warmup_min},
 token_warmup_step: {subset.token_warmup_step},
+alpha_mask: {subset.alpha_mask},
 """
 ),
 " ",


@@ -479,9 +479,10 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
 return noise

-def apply_masked_loss(loss, batch):
+def apply_masked_loss(loss, mask_image):
 # mask image is -1 to 1. we need to convert it to 0 to 1
-mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+# mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1)  # use R channel
+mask_image = mask_image.to(dtype=loss.dtype)
 # resize to the same size as the loss
 mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area")
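The refactored function above takes the mask tensor directly instead of pulling it out of the batch. A minimal self-contained sketch of the same idea (the name `masked_loss` and the dummy tensors are illustrative, not the library's API; the -1..1 normalization step for conditioning images is omitted):

```python
import torch
import torch.nn.functional as F

def masked_loss(loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Weight a per-pixel loss [B, C, H, W] by a mask resized to the loss resolution."""
    mask = mask.to(dtype=loss.dtype)
    # downscale the pixel-space mask (e.g. an alpha channel) to the latent-space loss size
    mask = F.interpolate(mask, size=loss.shape[2:], mode="area")
    # the single mask channel broadcasts across all loss channels
    return loss * mask

per_pixel = torch.ones(2, 4, 8, 8)   # dummy latent-space loss
alpha = torch.zeros(2, 1, 64, 64)    # alpha of 0 everywhere masks the whole loss out
weighted = masked_loss(per_pixel, alpha)
```

`mode="area"` averages the mask over each latent cell, so partially transparent regions down-weight the loss proportionally rather than being cut off at a threshold.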


@@ -159,6 +159,9 @@ class ImageInfo:
 self.text_encoder_outputs1: Optional[torch.Tensor] = None
 self.text_encoder_outputs2: Optional[torch.Tensor] = None
 self.text_encoder_pool2: Optional[torch.Tensor] = None
+self.alpha_mask: Optional[torch.Tensor] = None
+self.alpha_mask_flipped: Optional[torch.Tensor] = None
+self.use_alpha_mask: bool = False

 class BucketManager:
@@ -379,6 +382,7 @@ class BaseSubset:
 caption_suffix: Optional[str],
 token_warmup_min: int,
 token_warmup_step: Union[float, int],
+alpha_mask: bool,
 ) -> None:
 self.image_dir = image_dir
 self.num_repeats = num_repeats
@@ -403,6 +407,7 @@ class BaseSubset:
 self.img_count = 0
+self.alpha_mask = alpha_mask

 class DreamBoothSubset(BaseSubset):
 def __init__(
@@ -412,47 +417,13 @@ class DreamBoothSubset(BaseSubset):
 class_tokens: Optional[str],
 caption_extension: str,
 cache_info: bool,
-num_repeats,
-shuffle_caption,
-caption_separator: str,
-keep_tokens,
-keep_tokens_separator,
-secondary_separator,
-enable_wildcard,
-color_aug,
-flip_aug,
-face_crop_aug_range,
-random_crop,
-caption_dropout_rate,
-caption_dropout_every_n_epochs,
-caption_tag_dropout_rate,
-caption_prefix,
-caption_suffix,
-token_warmup_min,
-token_warmup_step,
+**kwargs,
 ) -> None:
 assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

 super().__init__(
 image_dir,
-num_repeats,
-shuffle_caption,
-caption_separator,
-keep_tokens,
-keep_tokens_separator,
-secondary_separator,
-enable_wildcard,
-color_aug,
-flip_aug,
-face_crop_aug_range,
-random_crop,
-caption_dropout_rate,
-caption_dropout_every_n_epochs,
-caption_tag_dropout_rate,
-caption_prefix,
-caption_suffix,
-token_warmup_min,
-token_warmup_step,
+**kwargs,
 )

 self.is_reg = is_reg
@@ -473,47 +444,13 @@ class FineTuningSubset(BaseSubset):
 self,
 image_dir,
 metadata_file: str,
-num_repeats,
-shuffle_caption,
-caption_separator,
-keep_tokens,
-keep_tokens_separator,
-secondary_separator,
-enable_wildcard,
-color_aug,
-flip_aug,
-face_crop_aug_range,
-random_crop,
-caption_dropout_rate,
-caption_dropout_every_n_epochs,
-caption_tag_dropout_rate,
-caption_prefix,
-caption_suffix,
-token_warmup_min,
-token_warmup_step,
+**kwargs,
 ) -> None:
 assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"

 super().__init__(
 image_dir,
-num_repeats,
-shuffle_caption,
-caption_separator,
-keep_tokens,
-keep_tokens_separator,
-secondary_separator,
-enable_wildcard,
-color_aug,
-flip_aug,
-face_crop_aug_range,
-random_crop,
-caption_dropout_rate,
-caption_dropout_every_n_epochs,
-caption_tag_dropout_rate,
-caption_prefix,
-caption_suffix,
-token_warmup_min,
-token_warmup_step,
+**kwargs,
 )

 self.metadata_file = metadata_file
@@ -531,47 +468,13 @@ class ControlNetSubset(BaseSubset):
 conditioning_data_dir: str,
 caption_extension: str,
 cache_info: bool,
-num_repeats,
-shuffle_caption,
-caption_separator,
-keep_tokens,
-keep_tokens_separator,
-secondary_separator,
-enable_wildcard,
-color_aug,
-flip_aug,
-face_crop_aug_range,
-random_crop,
-caption_dropout_rate,
-caption_dropout_every_n_epochs,
-caption_tag_dropout_rate,
-caption_prefix,
-caption_suffix,
-token_warmup_min,
-token_warmup_step,
+**kwargs,
 ) -> None:
 assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

 super().__init__(
 image_dir,
-num_repeats,
-shuffle_caption,
-caption_separator,
-keep_tokens,
-keep_tokens_separator,
-secondary_separator,
-enable_wildcard,
-color_aug,
-flip_aug,
-face_crop_aug_range,
-random_crop,
-caption_dropout_rate,
-caption_dropout_every_n_epochs,
-caption_tag_dropout_rate,
-caption_prefix,
-caption_suffix,
-token_warmup_min,
-token_warmup_step,
+**kwargs,
 )

 self.conditioning_data_dir = conditioning_data_dir
@@ -985,6 +888,8 @@ class BaseDataset(torch.utils.data.Dataset):
 for info in tqdm(image_infos):
 subset = self.image_to_subset[info.image_key]
+info.use_alpha_mask = subset.alpha_mask

 if info.latents_npz is not None:  # fine tuning dataset
 continue
@@ -1088,8 +993,8 @@ class BaseDataset(torch.utils.data.Dataset):
 def get_image_size(self, image_path):
 return imagesize.get(image_path)

-def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
-img = load_image(image_path)
+def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False):
+img = load_image(image_path, alpha_mask)

 face_cx = face_cy = face_w = face_h = 0
 if subset.face_crop_aug_range is not None:
@@ -1166,6 +1071,7 @@ class BaseDataset(torch.utils.data.Dataset):
 input_ids_list = []
 input_ids2_list = []
 latents_list = []
+alpha_mask_list = []
 images = []
 original_sizes_hw = []
 crop_top_lefts = []
@@ -1190,21 +1096,27 @@ class BaseDataset(torch.utils.data.Dataset):
 crop_ltrb = image_info.latents_crop_ltrb  # calc values later if flipped
 if not flipped:
 latents = image_info.latents
+alpha_mask = image_info.alpha_mask
 else:
 latents = image_info.latents_flipped
+alpha_mask = image_info.alpha_mask_flipped
 image = None
 elif image_info.latents_npz is not None:  # FineTuningDataset or when cache_latents_to_disk=True
-latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
+latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz)
 if flipped:
 latents = flipped_latents
+alpha_mask = flipped_alpha_mask
 del flipped_latents
+del flipped_alpha_mask
 latents = torch.FloatTensor(latents)
+if alpha_mask is not None:
+alpha_mask = torch.FloatTensor(alpha_mask)
 image = None
 else:
 # load the image, cropping it if necessary
-img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path)
+img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask)
 im_h, im_w = img.shape[0:2]

 if self.enable_bucket:
@@ -1241,11 +1153,22 @@ class BaseDataset(torch.utils.data.Dataset):
 if flipped:
 img = img[:, ::-1, :].copy()  # copy to avoid negative stride problem

+if subset.alpha_mask:
+if img.shape[2] == 4:
+alpha_mask = img[:, :, 3]  # [W,H]
+else:
+alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8)  # [W,H]
+alpha_mask = transforms.ToTensor()(alpha_mask)
+else:
+alpha_mask = None
+
+img = img[:, :, :3]  # remove alpha channel
+
 latents = None
 image = self.image_transforms(img)  # becomes a torch.Tensor in the range -1.0 to 1.0

 images.append(image)
 latents_list.append(latents)
+alpha_mask_list.append(alpha_mask)

 target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
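The alpha-extraction branch above can be pictured with plain NumPy (a sketch; `extract_alpha` is a made-up helper, and `transforms.ToTensor()` is replaced here by a simple scale to 0..1):

```python
import numpy as np
from typing import Tuple

def extract_alpha(img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Split an H x W x 4 (or H x W x 3) uint8 image into RGB and a 0..1 float alpha mask."""
    if img.shape[2] == 4:
        alpha = img[:, :, 3]
    else:
        # no alpha channel: treat the whole image as fully opaque, mirroring np.full(..., 255)
        alpha = np.full(img.shape[:2], 255, dtype=np.uint8)
    rgb = img[:, :, :3]  # drop the alpha channel before the normal image transforms
    return rgb, alpha.astype(np.float32) / 255.0

rgba = np.zeros((4, 6, 4), dtype=np.uint8)
rgba[:, :3, 3] = 255  # left half opaque, right half transparent
rgb, mask = extract_alpha(rgba)
```

Note the fully-opaque fallback: subsets with `alpha_mask` enabled can still mix in RGB images, which then contribute to the loss everywhere.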
@@ -1348,6 +1271,8 @@ class BaseDataset(torch.utils.data.Dataset):
 example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions))

+example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None
+
 if self.debug_dataset:
 example["image_keys"] = bucket[image_index : image_index + self.batch_size]
 return example
@@ -2145,7 +2070,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
 # returns latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
 npz_path,
-) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
+) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
 npz = np.load(npz_path)
 if "latents" not in npz:
 raise ValueError(f"error: npz is old format. please re-generate {npz_path}")

@@ -2154,13 +2079,19 @@ def load_latents_from_disk(
 original_size = npz["original_size"].tolist()
 crop_ltrb = npz["crop_ltrb"].tolist()
 flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
-return latents, original_size, crop_ltrb, flipped_latents
+alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
+flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None
+return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
+def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None):
 kwargs = {}
 if flipped_latents_tensor is not None:
 kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
+if alpha_mask is not None:
+kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
+if flipped_alpha_mask is not None:
+kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy()
 np.savez(
 npz_path,
 latents=latents_tensor.float().cpu().numpy(),
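The optional-key pattern used by `save_latents_to_disk` / `load_latents_from_disk` above keeps old cache files readable, since `np.savez` only writes the keys it is given. A standalone sketch (`save_cache`/`load_cache` and the arrays are made up for illustration):

```python
import os
import tempfile
import numpy as np

def save_cache(path, latents, alpha_mask=None):
    kwargs = {}
    if alpha_mask is not None:
        kwargs["alpha_mask"] = alpha_mask  # key is only written when the mask exists
    np.savez(path, latents=latents, **kwargs)

def load_cache(path):
    npz = np.load(path)
    # a missing key simply loads as None, so caches made before the feature still work
    alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
    return npz["latents"], alpha_mask

path = os.path.join(tempfile.mkdtemp(), "cache.npz")
save_cache(path, np.zeros((4, 8, 8), dtype=np.float32))  # no mask saved
latents, mask = load_cache(path)
```

This is why the loader returns `None` for absent masks instead of raising, matching the `alpha_mask is None` checks elsewhere in the diff.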
@@ -2349,17 +2280,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
 return train_dataset_group

-def load_image(image_path):
+def load_image(image_path, alpha=False):
 image = Image.open(image_path)
-if not image.mode == "RGB":
-image = image.convert("RGB")
+if alpha:
+image = image.convert("RGBA")
+else:
+image = image.convert("RGB")
 img = np.array(image, np.uint8)
 return img

 # loads an image; returns numpy.ndarray, (original width, original height), (crop left, crop top, crop right, crop bottom)
 def trim_and_resize_if_required(
-random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
+random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
 ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
 image_height, image_width = image.shape[0:2]
 original_size = (image_width, image_height)  # size before resize
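The `load_image` change above switches between `RGBA` and `RGB` conversion based on the new `alpha` flag. A minimal sketch with Pillow (the temp-file setup is only for demonstration; the real function just takes a path):

```python
import os
import tempfile
import numpy as np
from PIL import Image

def load_image(image_path: str, alpha: bool = False) -> np.ndarray:
    image = Image.open(image_path)
    # RGBA preserves transparency; RGB flattens palette/greyscale/alpha variants as before
    image = image.convert("RGBA" if alpha else "RGB")
    return np.array(image, np.uint8)

path = os.path.join(tempfile.mkdtemp(), "t.png")
Image.new("RGBA", (6, 4), (10, 20, 30, 128)).save(path)  # half-transparent test image
rgb = load_image(path)               # 3 channels, alpha discarded
rgba = load_image(path, alpha=True)  # 4 channels, alpha kept for masking
```

Unconditionally calling `convert()` (rather than the old `if not image.mode == "RGB"` guard) also normalizes palette and greyscale PNGs to a predictable channel count.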
@@ -2403,10 +2337,18 @@ def cache_batch_latents(
 latents_original_size and latents_crop_ltrb are also set
 """
 images = []
+alpha_masks = []
 for info in image_infos:
-image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
+image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
 # TODO the image metadata may be corrupted, so the bucket assigned from the metadata may not match the actual image size; a check should be added
 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
+if info.use_alpha_mask:
+if image.shape[2] == 4:
+alpha_mask = image[:, :, 3]  # [W,H]
+image = image[:, :, :3]
+else:
+alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8)  # [W,H]
+alpha_masks.append(transforms.ToTensor()(alpha_mask))

 image = IMAGE_TRANSFORMS(image)
 images.append(image)
@@ -2419,25 +2361,37 @@ def cache_batch_latents(
 with torch.no_grad():
 latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")

+if info.use_alpha_mask:
+alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu")
+else:
+alpha_masks = [None] * len(image_infos)
+flipped_alpha_masks = [None] * len(image_infos)
+
 if flip_aug:
 img_tensors = torch.flip(img_tensors, dims=[3])
 with torch.no_grad():
 flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+if info.use_alpha_mask:
+flipped_alpha_masks = torch.flip(alpha_masks, dims=[3])
 else:
 flipped_latents = [None] * len(latents)
+flipped_alpha_masks = [None] * len(image_infos)

-for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
+for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks):
 # check NaN
 if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
 raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

 if cache_to_disk:
-save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
+save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask)
 else:
 info.latents = latent
 if flip_aug:
 info.latents_flipped = flipped_latent
+info.alpha_mask = alpha_mask
+info.alpha_mask_flipped = flipped_alpha_mask

 if not HIGH_VRAM:
 clean_memory_on_device(vae.device)
@@ -3683,6 +3637,11 @@ def add_dataset_arguments(
 default=0,
 help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N（N<1ならN*max_train_steps）ステップでタグ長が最大になる。デフォルトは0（最初から最大）",
 )
+parser.add_argument(
+"--alpha_mask",
+action="store_true",
+help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する",
+)
 parser.add_argument(
 "--dataset_class",


@@ -712,7 +712,9 @@ def train(args):
 noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
 )
 if args.masked_loss:
-loss = apply_masked_loss(loss, batch)
+loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
+if "alpha_mask" in batch and batch["alpha_mask"] is not None:
+loss = apply_masked_loss(loss, batch["alpha_mask"])
 loss = loss.mean([1, 2, 3])

 if args.min_snr_gamma:
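With this change (repeated in each training script below), the conditioning-image mask and the alpha mask are applied one after the other, so their weights multiply. A sketch of that composition (`apply_mask` and the tensors are stand-ins for the real `apply_masked_loss` and batch):

```python
import torch
import torch.nn.functional as F

def apply_mask(loss, mask):
    mask = F.interpolate(mask.to(loss.dtype), size=loss.shape[2:], mode="area")
    return loss * mask

loss = torch.ones(1, 4, 8, 8)               # dummy per-pixel loss
cond_mask = torch.full((1, 1, 8, 8), 0.5)   # e.g. derived from conditioning images
alpha_mask = torch.full((1, 1, 8, 8), 0.5)  # from the image's alpha channel

loss = apply_mask(loss, cond_mask)
loss = apply_mask(loss, alpha_mask)  # the two weights multiply: 0.5 * 0.5 = 0.25
```

Applying both masks is therefore a logical AND-like weighting: a pixel contributes fully only where both masks are 1.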


@@ -360,7 +360,9 @@ def train(args):
 loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
 if args.masked_loss:
-loss = apply_masked_loss(loss, batch)
+loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
+if "alpha_mask" in batch and batch["alpha_mask"] is not None:
+loss = apply_masked_loss(loss, batch["alpha_mask"])
 loss = loss.mean([1, 2, 3])

 loss_weights = batch["loss_weights"]  # per-sample weight


@@ -903,7 +903,9 @@ class NetworkTrainer:
 noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
 )
 if args.masked_loss:
-loss = apply_masked_loss(loss, batch)
+loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
+if "alpha_mask" in batch and batch["alpha_mask"] is not None:
+loss = apply_masked_loss(loss, batch["alpha_mask"])
 loss = loss.mean([1, 2, 3])

 loss_weights = batch["loss_weights"]  # per-sample weight


@@ -590,7 +590,9 @@ class TextualInversionTrainer:
 loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
 if args.masked_loss:
-loss = apply_masked_loss(loss, batch)
+loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
+if "alpha_mask" in batch and batch["alpha_mask"] is not None:
+loss = apply_masked_loss(loss, batch["alpha_mask"])
 loss = loss.mean([1, 2, 3])

 loss_weights = batch["loss_weights"]  # per-sample weight


@@ -475,7 +475,9 @@ def train(args):
 loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
 if args.masked_loss:
-loss = apply_masked_loss(loss, batch)
+loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1))
+if "alpha_mask" in batch and batch["alpha_mask"] is not None:
+loss = apply_masked_loss(loss, batch["alpha_mask"])
 loss = loss.mean([1, 2, 3])

 loss_weights = batch["loss_weights"]  # per-sample weight