diff --git a/library/sd3_train_utils.py b/library/sd3_train_utils.py index 66034210..24591219 100644 --- a/library/sd3_train_utils.py +++ b/library/sd3_train_utils.py @@ -1,4 +1,5 @@ import argparse +import glob import math import os from typing import List, Optional, Tuple, Union @@ -282,12 +283,26 @@ def sample_images(*args, **kwargs): class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz" - def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check) + self.vae = None + + def set_vae(self, vae: sd3_models.SDVAE): self.vae = vae - def get_latents_npz_path(self, absolute_path: str): - return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: + npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX) + if len(npz_file) == 0: + return None, None + w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x") + return int(w), int(h) + + def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str: + return ( + os.path.splitext(absolute_path)[0] + + f"_{image_size[0]:04d}x{image_size[1]:04d}" + + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX + ) def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): if not self.cache_to_disk: @@ -331,24 +346,24 @@ class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype) with torch.no_grad(): - latents = self.vae.encode(img_tensor).to("cpu") + latents_tensors = self.vae.encode(img_tensor).to("cpu") if flip_aug: img_tensor = torch.flip(img_tensor, dims=[3]) with torch.no_grad(): flipped_latents = self.vae.encode(img_tensor).to("cpu") else: - flipped_latents = [None] * len(latents) + flipped_latents = [None] * len(latents_tensors) + + # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks): + for i in range(len(image_infos)): + info = image_infos[i] + latents = latents_tensors[i] + flipped_latent = flipped_latents[i] + alpha_mask = alpha_masks[i] + original_size = original_sizes[i] + crop_ltrb = crop_ltrbs[i] - for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): if self.cache_to_disk: - # save_latents_to_disk( - # info.latents_npz, - # latent, - # info.latents_original_size, - # info.latents_crop_ltrb, - # flipped_latent, - # alpha_mask, - # ) kwargs = {} if flipped_latent is not None: kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy() @@ -357,12 +372,12 @@ class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy): np.savez( info.latents_npz, latents=latents.float().cpu().numpy(), - original_size=np.array(original_sizes), - crop_ltrb=np.array(crop_ltrbs), + original_size=np.array(original_size), + crop_ltrb=np.array(crop_ltrb), **kwargs, ) else: - info.latents = latent + info.latents = latents if flip_aug: info.latents_flipped = flipped_latent info.alpha_mask = alpha_mask diff --git a/library/train_util.py b/library/train_util.py index 8444827d..9db226ea 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -360,11 +360,23 @@ class AugHelper: class LatentsCachingStrategy: + _strategy = None # strategy instance: actual strategy class + def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: self._cache_to_disk = cache_to_disk self._batch_size = batch_size self.skip_disk_cache_validity_check = skip_disk_cache_validity_check + @classmethod + def set_strategy(cls, strategy): + if cls._strategy is not None: + raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set") + cls._strategy = strategy + + @classmethod + def get_strategy(cls) -> Optional["LatentsCachingStrategy"]: + return cls._strategy + @property def cache_to_disk(self): return self._cache_to_disk @@ -373,10 +385,15 @@ class LatentsCachingStrategy: def batch_size(self): return self._batch_size - def get_latents_npz_path(self, absolute_path: str): + def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]: raise NotImplementedError - def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): + def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str: + raise NotImplementedError + + def is_disk_cached_latents_expected( + self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool + ) -> bool: raise NotImplementedError def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -1034,7 +1051,7 @@ class BaseDataset(torch.utils.data.Dataset): # check disk cache exists and size of latents if caching_strategy.cache_to_disk: # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix - info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path) + info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size) if not is_main_process: # prepare for multi-gpu, only store to info continue @@ -1730,6 +1747,18 @@ class DreamBoothDataset(BaseDataset): img_paths = glob_images(subset.image_dir, "*") sizes = [None] * len(img_paths) + # new caching: get image size from cache files + strategy = LatentsCachingStrategy.get_strategy() + if strategy is not None: + logger.info("get image size from cache files") + size_set_count = 0 + for i, img_path in enumerate(tqdm(img_paths)): + w, h = strategy.get_image_size_from_image_absolute_path(img_path) + if w is not None and h is not None: + sizes[i] = [w, h] + size_set_count += 1 + logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}") + logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") if use_cached_info_for_subset: @@ -2807,12 +2836,12 @@ def cache_batch_text_encoder_outputs_sd3( b_lg_out = b_lg_out.detach() b_t5_out = b_t5_out.detach() b_pool = b_pool.detach() - + for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): # debug: NaN check if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any(): raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}") - + if cache_to_disk: save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) else: diff --git a/sd3_train.py b/sd3_train.py index 30d994c7..e2f622e4 100644 --- a/sd3_train.py +++ b/sd3_train.py @@ -91,6 +91,15 @@ def train(args): # load tokenizer sd3_tokenizer = sd3_models.SD3Tokenizer() + # prepare caching strategy + if args.new_caching: + latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy( + args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check + ) + else: + latents_caching_strategy = None + train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy) + # データセットを準備する if args.dataset_class is None: blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True)) @@ -217,10 +226,8 @@ def train(args): file_suffix="_sd3.npz", ) else: - strategy = sd3_train_utils.Sd3LatentsCachingStrategy( - vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check - ) - train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy) + latents_caching_strategy.set_vae(vae) + train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy) vae.to("cpu") # if no sampling, vae can be deleted clean_memory_on_device(accelerator.device)