diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 0389da38..019c737a 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -11,6 +11,7 @@ import cv2 import torch from library.device_utils import init_ipex, get_preferred_device + init_ipex() from torchvision import transforms @@ -18,8 +19,10 @@ from torchvision import transforms import library.model_util as model_util import library.train_util as train_util from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) DEVICE = get_preferred_device() @@ -89,7 +92,9 @@ def main(args): # bucketのサイズを計算する max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) - assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" + assert ( + len(max_reso) == 2 + ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" bucket_manager = train_util.BucketManager( args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps @@ -107,7 +112,7 @@ def main(args): def process_batch(is_last): for bucket in bucket_manager.buckets: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) + train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False) bucket.clear() # 読み込みの高速化のためにDataLoaderを使うオプション @@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") + parser.add_argument( + "--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", @@ -231,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser: help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument( "--full_path", @@ -242,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser: help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", ) parser.add_argument( - "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" + "--flip_aug", + action="store_true", + help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する", + ) + parser.add_argument( + "--alpha_mask", + type=str, + default="", + help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する", ) parser.add_argument( "--skip_existing", diff --git a/library/config_util.py b/library/config_util.py index 82baab83..964270db 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -214,11 +214,13 @@ class ConfigSanitizer: DB_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, "is_reg": bool, + "alpha_mask": bool, } # FT means FineTuning FT_SUBSET_DISTINCT_SCHEMA = { Required("metadata_file"): str, "image_dir": str, + "alpha_mask": bool, } CN_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index fad12740..af5813a1 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -479,14 +479,19 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, mask_image): - # mask image is -1 to 1. we need to convert it to 0 to 1 - # mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel - mask_image = mask_image.to(dtype=loss.dtype) +def apply_masked_loss(loss, batch): + if "conditioning_images" in batch: + # conditioning image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image / 2 + 0.5 + elif "alpha_masks" in batch and batch["alpha_masks"] is not None: + # alpha mask is 0 to 1 + mask_image = batch["alpha_masks"].to(dtype=loss.dtype) + else: + return loss # resize to the same size as the loss mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") - mask_image = mask_image / 2 + 0.5 loss = loss * mask_image return loss diff --git a/library/train_util.py b/library/train_util.py index 6cf28590..e7a50f04 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -159,9 +159,7 @@ class ImageInfo: self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None - self.alpha_mask: Optional[torch.Tensor] = None - self.alpha_mask_flipped: Optional[torch.Tensor] = None - self.use_alpha_mask: bool = False + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime class BucketManager: @@ -364,6 +362,7 @@ class BaseSubset: def __init__( self, image_dir: Optional[str], + alpha_mask: Optional[bool], num_repeats: int, shuffle_caption: bool, caption_separator: str, @@ -382,9 +381,9 @@ class BaseSubset: caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], - alpha_mask: bool, ) -> None: self.image_dir = image_dir + self.alpha_mask = alpha_mask if alpha_mask is not None else False self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator @@ -407,8 +406,6 @@ class BaseSubset: self.img_count = 0 - self.alpha_mask = alpha_mask - class DreamBoothSubset(BaseSubset): def __init__( @@ -418,6 +415,7 @@ class DreamBoothSubset(BaseSubset): class_tokens: Optional[str], caption_extension: str, cache_info: bool, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator: str, @@ -441,6 +439,7 @@ class DreamBoothSubset(BaseSubset): super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -479,6 +478,7 @@ class FineTuningSubset(BaseSubset): self, image_dir, metadata_file: str, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator, @@ -502,6 +502,7 @@ class FineTuningSubset(BaseSubset): super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -921,7 +922,7 @@ class BaseDataset(torch.utils.data.Dataset): logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] + self.buckets_indices: List[BucketBatchIndex] = [] for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): @@ -991,8 +992,6 @@ class BaseDataset(torch.utils.data.Dataset): for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] - info.use_alpha_mask = subset.alpha_mask - if info.latents_npz is not None: # fine tuning dataset continue @@ -1002,7 +1001,9 @@ class BaseDataset(torch.utils.data.Dataset): if not is_main_process: # store to info only continue - cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) if cache_available: # do not add to batch continue @@ -1028,7 +1029,7 @@ class BaseDataset(torch.utils.data.Dataset): # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) + cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -1202,18 +1203,15 @@ class BaseDataset(torch.utils.data.Dataset): alpha_mask = image_info.alpha_mask else: latents = image_info.latents_flipped - alpha_mask = image_info.alpha_mask_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk( - image_info.latents_npz - ) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) if flipped: latents = flipped_latents - alpha_mask = flipped_alpha_mask + alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem del flipped_latents - del flipped_alpha_mask latents = torch.FloatTensor(latents) if alpha_mask is not None: alpha_mask = torch.FloatTensor(alpha_mask) @@ -1255,23 +1253,28 @@ class BaseDataset(torch.utils.data.Dataset): # augmentation aug = self.aug_helper.get_augmentor(subset.color_aug) if aug is not None: - img = aug(image=img)["image"] + # augment RGB channels only + img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb if flipped: img = img[:, ::-1, :].copy() # copy to avoid negative stride problem if subset.alpha_mask: if img.shape[2] == 4: - alpha_mask = img[:, :, 3] # [W,H] + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1 else: - alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] - alpha_mask = transforms.ToTensor()(alpha_mask) + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) else: alpha_mask = None + img = img[:, :, :3] # remove alpha channel latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + del img images.append(image) latents_list.append(latents) @@ -1361,6 +1364,23 @@ class BaseDataset(torch.utils.data.Dataset): example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + # if one of alpha_masks is not None, we need to replace None with ones + none_or_not = [x is None for x in alpha_mask_list] + if all(none_or_not): + example["alpha_masks"] = None + elif any(none_or_not): + for i in range(len(alpha_mask_list)): + if alpha_mask_list[i] is None: + if images[i] is not None: + alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32) + else: + alpha_mask_list[i] = torch.ones( + (latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32 + ) + example["alpha_masks"] = torch.stack(alpha_mask_list) + else: + example["alpha_masks"] = torch.stack(alpha_mask_list) + if images[0] is not None: images = torch.stack(images) images = images.to(memory_format=torch.contiguous_format).float() @@ -1378,8 +1398,6 @@ class BaseDataset(torch.utils.data.Dataset): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) - example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None - if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1393,6 +1411,7 @@ class BaseDataset(torch.utils.data.Dataset): resized_sizes = [] bucket_reso = None flip_aug = None + alpha_mask = None random_crop = None for image_key in bucket[image_index : image_index + bucket_batch_size]: @@ -1401,10 +1420,13 @@ class BaseDataset(torch.utils.data.Dataset): if flip_aug is None: flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask random_crop = subset.random_crop bucket_reso = image_info.bucket_reso else: + # TODO そもそも混在してても動くようにしたほうがいい assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" assert random_crop == subset.random_crop, "random_crop must be same in a batch" assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" @@ -1441,6 +1463,7 @@ class BaseDataset(torch.utils.data.Dataset): example["absolute_paths"] = absolute_paths example["resized_sizes"] = resized_sizes example["flip_aug"] = flip_aug + example["alpha_mask"] = alpha_mask example["random_crop"] = random_crop example["bucket_reso"] = bucket_reso return example @@ -2149,7 +2172,7 @@ class DatasetGroup(torch.utils.data.ConcatDataset): dataset.disable_token_padding() -def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 if not os.path.exists(npz_path): @@ -2167,6 +2190,12 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): return False if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != reso: # HxW + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -2177,14 +2206,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[ - Optional[torch.Tensor], - Optional[List[int]], - Optional[List[int]], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], -]: +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2194,20 +2216,15 @@ def load_latents_from_disk( crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None - flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask -def save_latents_to_disk( - npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None -): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - if flipped_alpha_mask is not None: - kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy() + kwargs["alpha_mask"] = alpha_mask # ndarray np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2398,10 +2415,11 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: def load_image(image_path, alpha=False): image = Image.open(image_path) - if not image.mode == "RGB": - if alpha: + if alpha: + if not image.mode == "RGBA": image = image.convert("RGBA") - else: + else: + if not image.mode == "RGB": image = image.convert("RGB") img = np.array(image, np.uint8) return img @@ -2441,7 +2459,7 @@ def trim_and_resize_if_required( def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -2453,49 +2471,43 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] - alpha_masks = [] + alpha_masks: List[np.ndarray] = [] for info in image_infos: - image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) - if info.use_alpha_mask: - if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [W,H] - image = image[:, :, :3] - else: - alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] - alpha_masks.append(transforms.ToTensor()(alpha_mask)) - image = IMAGE_TRANSFORMS(image) - images.append(image) info.latents_original_size = original_size info.latents_crop_ltrb = crop_ltrb + if use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 + else: + alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32) + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists + image = IMAGE_TRANSFORMS(image) + images.append(image) + img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) with torch.no_grad(): latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - if info.use_alpha_mask: - alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu") - else: - alpha_masks = [None] * len(image_infos) - flipped_alpha_masks = [None] * len(image_infos) - if flip_aug: img_tensors = torch.flip(img_tensors, dims=[3]) with torch.no_grad(): flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - if info.use_alpha_mask: - flipped_alpha_masks = torch.flip(alpha_masks, dims=[3]) else: flipped_latents = [None] * len(latents) - flipped_alpha_masks = [None] * len(image_infos) - for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip( - image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks - ): + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") @@ -2508,15 +2520,12 @@ def cache_batch_latents( info.latents_crop_ltrb, flipped_latent, alpha_mask, - flipped_alpha_mask, ) else: info.latents = latent if flip_aug: info.latents_flipped = flipped_latent - info.alpha_mask = alpha_mask - info.alpha_mask_flipped = flipped_alpha_mask if not HIGH_VRAM: clean_memory_on_device(vae.device) diff --git a/sdxl_train.py b/sdxl_train.py index dcd06766..9e20c60c 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -711,10 +711,8 @@ def train(args): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 347db27f..b7c88121 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -17,10 +17,13 @@ from library.config_util import ( BlueprintGenerator, ) from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -107,7 +110,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: else: _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) @@ -136,6 +139,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: b_size = len(batch["images"]) vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size flip_aug = batch["flip_aug"] + alpha_mask = batch["alpha_mask"] random_crop = batch["random_crop"] bucket_reso = batch["bucket_reso"] @@ -154,14 +158,16 @@ def cache_to_disk(args: argparse.Namespace) -> None: image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): + if train_util.is_disk_cached_latents_is_expected( + image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask + ): logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) + train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) accelerator.wait_for_everyone() accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") diff --git a/train_db.py b/train_db.py index c4690000..39d8ea6e 100644 --- a/train_db.py +++ b/train_db.py @@ -359,10 +359,8 @@ def train(args): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_network.py b/train_network.py index cd1677ad..b272a6e1 100644 --- a/train_network.py +++ b/train_network.py @@ -774,7 +774,9 @@ class NetworkTrainer: if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -902,10 +904,8 @@ class NetworkTrainer: loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion.py b/train_textual_inversion.py index a9c2a109..ade077c3 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -589,10 +589,8 @@ class TextualInversionTrainer: target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 959839cb..efb59137 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -474,10 +474,8 @@ def train(args): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight