From 6ab48b09d8e46973d5e5fa47baeae3a464d06d04 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Tue, 20 Aug 2024 21:39:43 +0900
Subject: [PATCH] feat: Support multi-resolution training with caching latents to disk

---
 README.md                |  11 +++-
 library/strategy_base.py | 112 ++++++++++++++++++++++++++-------------
 library/strategy_flux.py |  11 +++-
 library/train_util.py    |   2 +-
 4 files changed, 93 insertions(+), 43 deletions(-)

diff --git a/README.md b/README.md
index 3f5c4daa..1d44c9e5 100644
--- a/README.md
+++ b/README.md
@@ -9,13 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
 
 The command to install PyTorch is as follows:
 `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
 
+Aug 20, 2024 (update 3):
+__Experimental__ Multi-resolution training is now supported with caching latents to disk.
+
+The cache files now hold latents for multiple resolutions. Since latents are appended to the existing cache file, it is recommended to delete the cache files in advance (otherwise, the old latents are kept in the .npz file).
+
+See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
+
 Aug 20, 2024 (update 2):
 `flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015!
 
 Aug 20, 2024:
 FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution).
 
-The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
+The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try it if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
 
 We will support multi-resolution caching to disk in the near future.
@@ -171,7 +178,7 @@ The script can merge multiple LoRA models. If you want to merge multiple LoRA mo
 
 ### FLUX.1 Multi-resolution training
 
-You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__
+You can define multiple resolutions in the dataset configuration file.
 
 The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution.
 
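For reference, a multi-resolution dataset configuration of the kind the README section above describes might look like the following (a minimal sketch; the image directory, resolutions, and batch sizes are placeholders, with the same directory repeated for each resolution as the README advises):

```toml
# Illustrative multi-resolution dataset config (values are placeholders).
[general]
enable_bucket = true

[[datasets]]
resolution = [1024, 1024]
batch_size = 1

  [[datasets.subsets]]
  image_dir = "path/to/image/dir"   # same directory for every resolution
  num_repeats = 1

[[datasets]]
resolution = [768, 768]
batch_size = 2

  [[datasets.subsets]]
  image_dir = "path/to/image/dir"
  num_repeats = 1
```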
diff --git a/library/strategy_base.py b/library/strategy_base.py
index a99a0829..e7d3a97e 100644
--- a/library/strategy_base.py
+++ b/library/strategy_base.py
@@ -219,7 +219,13 @@ class LatentsCachingStrategy:
         raise NotImplementedError
 
     def _default_is_disk_cached_latents_expected(
-        self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
+        self,
+        latents_stride: int,
+        bucket_reso: Tuple[int, int],
+        npz_path: str,
+        flip_aug: bool,
+        alpha_mask: bool,
+        multi_resolution: bool = False,
     ):
         if not self.cache_to_disk:
             return False
@@ -230,25 +236,17 @@ class LatentsCachingStrategy:
 
         expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
 
+        # e.g. "_32x64", HxW
+        key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
+
         try:
             npz = np.load(npz_path)
-            if npz["latents"].shape[1:3] != expected_latents_size:
+            if "latents" + key_reso_suffix not in npz:
+                return False
+            if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
+                return False
+            if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
                 return False
-
-            if flip_aug:
-                if "latents_flipped" not in npz:
-                    return False
-                if npz["latents_flipped"].shape[1:3] != expected_latents_size:
-                    return False
-
-            if alpha_mask:
-                if "alpha_mask" not in npz:
-                    return False
-                if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
-                    return False
-            else:
-                if "alpha_mask" in npz:
-                    return False
         except Exception as e:
             logger.error(f"Error loading file: {npz_path}")
             raise e
@@ -257,7 +255,15 @@ class LatentsCachingStrategy:
 
     # TODO remove circular dependency for ImageInfo
     def _default_cache_batch_latents(
-        self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
+        self,
+        encode_by_vae,
+        vae_device,
+        vae_dtype,
+        image_infos: List,
+        flip_aug: bool,
+        alpha_mask: bool,
+        random_crop: bool,
+        multi_resolution: bool = False,
     ):
         """
         Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
@@ -287,8 +293,13 @@ class LatentsCachingStrategy:
             original_size = original_sizes[i]
             crop_ltrb = crop_ltrbs[i]
 
+            latents_size = latents.shape[1:3]  # H, W
+            key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else ""  # e.g. "_32x64", HxW
+
             if self.cache_to_disk:
-                self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
+                self.save_latents_to_disk(
+                    info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
+                )
             else:
                 info.latents_original_size = original_size
                 info.latents_crop_ltrb = crop_ltrb
@@ -298,31 +309,56 @@ class LatentsCachingStrategy:
             info.alpha_mask = alpha_mask
 
     def load_latents_from_disk(
-        self, npz_path: str
+        self, npz_path: str, bucket_reso: Tuple[int, int]
     ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
-        npz = np.load(npz_path)
-        if "latents" not in npz:
-            raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
+        """
+        for SD/SDXL/SD3.0
+        """
+        return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
 
-        latents = npz["latents"]
-        original_size = npz["original_size"].tolist()
-        crop_ltrb = npz["crop_ltrb"].tolist()
-        flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
-        alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
+    def _default_load_latents_from_disk(
+        self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
+    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
+        if latents_stride is None:
+            key_reso_suffix = ""
+        else:
+            latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)
+            key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}"  # e.g. "_32x64", HxW
+
+        npz = np.load(npz_path)
+        if "latents" + key_reso_suffix not in npz:
+            raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
+
+        latents = npz["latents" + key_reso_suffix]
+        original_size = npz["original_size" + key_reso_suffix].tolist()
+        crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
+        flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
+        alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
 
         return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
 
     def save_latents_to_disk(
-        self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
+        self,
+        npz_path,
+        latents_tensor,
+        original_size,
+        crop_ltrb,
+        flipped_latents_tensor=None,
+        alpha_mask=None,
+        key_reso_suffix="",
     ):
         kwargs = {}
+
+        if os.path.exists(npz_path):
+            # load existing npz and update it
+            npz = np.load(npz_path)
+            for key in npz.files:
+                kwargs[key] = npz[key]
+
+        kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
+        kwargs["original_size" + key_reso_suffix] = np.array(original_size)
+        kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
         if flipped_latents_tensor is not None:
-            kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
+            kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
         if alpha_mask is not None:
-            kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
-        np.savez(
-            npz_path,
-            latents=latents_tensor.float().cpu().numpy(),
-            original_size=np.array(original_size),
-            crop_ltrb=np.array(crop_ltrb),
-            **kwargs,
-        )
+            kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
+        np.savez(npz_path, **kwargs)
diff --git a/library/strategy_flux.py b/library/strategy_flux.py
index 3880a1e1..5c620f3d 100644
--- a/library/strategy_flux.py
+++ b/library/strategy_flux.py
@@ -200,7 +200,12 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
         )
 
     def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
-        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
+        return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True)
+
+    def load_latents_from_disk(
+        self, npz_path: str, bucket_reso: Tuple[int, int]
+    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
+        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)  # support multi-resolution
 
     # TODO remove circular dependency for ImageInfo
     def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -208,7 +213,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
         vae_device = vae.device
         vae_dtype = vae.dtype
 
-        self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
+        self._default_cache_batch_latents(
+            encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True
+        )
 
         if not train_util.HIGH_VRAM:
             train_util.clean_memory_on_device(vae.device)
diff --git a/library/train_util.py b/library/train_util.py
index f4ac8740..8929c192 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1381,7 +1381,7 @@ class BaseDataset(torch.utils.data.Dataset):
                 image = None
             elif image_info.latents_npz is not None:  # for FineTuningDataset, or when cache_latents_to_disk=True
                 latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
-                    self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz)
+                    self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
                 )
                 if flipped:
                     latents = flipped_latents
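With this change, a single .npz cache file can hold latents for several resolutions side by side: each array is stored under a key with a `_{H}x{W}` suffix in latent units (the bucket size divided by 8 for FLUX.1), and `save_latents_to_disk` merges new keys into any existing file. A minimal sketch of reading such a cache (the file name is hypothetical; the key layout follows the code above, assuming an image cached at both 1024x1024 and 768x768):

```python
import numpy as np

# Hypothetical cache file written for both a 1024x1024 and a 768x768 bucket.
npz = np.load("img_0001.npz")
print(sorted(npz.files))
# ['crop_ltrb_128x128', 'crop_ltrb_96x96',
#  'latents_128x128', 'latents_96x96',
#  'original_size_128x128', 'original_size_96x96']

# The 768x768 bucket with latent stride 8 maps to the "_96x96" keys (HxW).
latents = npz["latents_96x96"]                       # float32, shape (C, 96, 96)
original_size = npz["original_size_96x96"].tolist()  # size of the source image
crop_ltrb = npz["crop_ltrb_96x96"].tolist()          # crop (left, top, right, bottom)
```

Because existing keys are preserved when a file is updated, caches written before this change keep their old un-suffixed arrays next to the new suffixed ones, which is why the README above recommends deleting old cache files first.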