From 2e9f7b5f9135dd9a970bac863907f63adb53f943 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Wed, 12 Apr 2023 23:10:39 +0900
Subject: [PATCH] cache latents to disk in dreambooth method

---
 fine_tune.py                   |  4 ++-
 library/train_util.py          | 62 ++++++++++++++++++++++++++++------
 train_db.py                    |  4 ++-
 train_network.py               |  4 ++-
 train_textual_inversion.py     |  4 ++-
 train_textual_inversion_XTI.py |  4 ++-
 6 files changed, 67 insertions(+), 15 deletions(-)

diff --git a/fine_tune.py b/fine_tune.py
index 2157de98..47454670 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -142,12 +142,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     # 学習を準備する:モデルを適切な状態にする
     training_models = []
     if args.gradient_checkpointing:
diff --git a/library/train_util.py b/library/train_util.py
index 56eef81f..6b398707 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -722,7 +722,7 @@ class BaseDataset(torch.utils.data.Dataset):
     def is_latent_cacheable(self):
         return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
 
-    def cache_latents(self, vae, vae_batch_size=1):
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         # ちょっと速くした
         print("caching latents.")
 
@@ -740,11 +740,38 @@ class BaseDataset(torch.utils.data.Dataset):
             if info.latents_npz is not None:
                 info.latents = self.load_latents_from_npz(info, False)
                 info.latents = torch.FloatTensor(info.latents)
-                info.latents_flipped = self.load_latents_from_npz(info, True)  # might be None
+
+                # might be None, but that's ok because check is done in dataset
+                info.latents_flipped = self.load_latents_from_npz(info, True)
                 if info.latents_flipped is not None:
                     info.latents_flipped = torch.FloatTensor(info.latents_flipped)
                 continue
 
+            # check disk cache exists and size of latents
+            if cache_to_disk:
+                # TODO: refactor to unify with FineTuningDataset
+                info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
+                info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
+                if not is_main_process:
+                    continue
+
+                cache_available = False
+                expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8)  # bucket_resoはWxHなので注意
+                if os.path.exists(info.latents_npz):
+                    cached_latents = np.load(info.latents_npz)
+                    if cached_latents["latents"].shape[1:3] == expected_latents_size:
+                        cache_available = True
+
+                if subset.flip_aug:
+                    cache_available = False
+                    if os.path.exists(info.latents_npz_flipped):
+                        cached_latents_flipped = np.load(info.latents_npz_flipped)
+                        if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size:
+                            cache_available = True
+
+                if cache_available:
+                    continue
+
             # if last member of batch has different resolution, flush the batch
             if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
                 batches.append(batch)
@@ -760,6 +787,9 @@ class BaseDataset(torch.utils.data.Dataset):
         if len(batch) > 0:
             batches.append(batch)
 
+        if cache_to_disk and not is_main_process:  # don't cache latents in non-main process, set to info only
+            return
+
         # iterate batches
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
             images = []
@@ -773,14 +803,21 @@ class BaseDataset(torch.utils.data.Dataset):
             img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
 
             latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+
             for info, latent in zip(batch, latents):
-                info.latents = latent
+                if cache_to_disk:
+                    np.savez(info.latents_npz, latent.float().numpy())
+                else:
+                    info.latents = latent
 
             if subset.flip_aug:
                 img_tensors = torch.flip(img_tensors, dims=[3])
                 latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
                 for info, latent in zip(batch, latents):
-                    info.latents_flipped = latent
+                    if cache_to_disk:
+                        np.savez(info.latents_npz_flipped, latent.float().numpy())
+                    else:
+                        info.latents_flipped = latent
 
     def get_image_size(self, image_path):
         image = Image.open(image_path)
@@ -873,10 +910,10 @@ class BaseDataset(torch.utils.data.Dataset):
             loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
 
             # image/latentsを処理する
-            if image_info.latents is not None:
+            if image_info.latents is not None:  # cache_latents=Trueの場合
                 latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
                 image = None
-            elif image_info.latents_npz is not None:
+            elif image_info.latents_npz is not None:  # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
                 latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
                 latents = torch.FloatTensor(latents)
                 image = None
@@ -1340,10 +1377,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
         for dataset in self.datasets:
             dataset.enable_XTI(*args, **kwargs)
 
-    def cache_latents(self, vae, vae_batch_size=1):
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         for i, dataset in enumerate(self.datasets):
             print(f"[Dataset {i}]")
-            dataset.cache_latents(vae, vae_batch_size)
+            dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
 
     def is_latent_cacheable(self) -> bool:
         return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -2144,9 +2181,14 @@ def add_dataset_arguments(
     parser.add_argument(
         "--cache_latents",
         action="store_true",
-        help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)",
+        help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可) ",
     )
     parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
+    parser.add_argument(
+        "--cache_latents_to_disk",
+        action="store_true",
+        help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
+    )
     parser.add_argument(
         "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
     )
@@ -3203,4 +3245,4 @@ class collater_class:
             # set epoch and step
             dataset.set_current_epoch(self.current_epoch.value)
             dataset.set_current_step(self.current_step.value)
-        return examples[0]
\ No newline at end of file
+        return examples[0]
diff --git a/train_db.py b/train_db.py
index e72dc889..eddf8f68 100644
--- a/train_db.py
+++ b/train_db.py
@@ -117,12 +117,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     # 学習を準備する:モデルを適切な状態にする
     train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
     unet.requires_grad_(True)  # 念のため追加
diff --git a/train_network.py b/train_network.py
index ef630969..fb3d6130 100644
--- a/train_network.py
+++ b/train_network.py
@@ -172,12 +172,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     # prepare network
     import sys
 
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 98639345..88ddebdd 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -233,12 +233,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index db46ad1b..d302491e 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -267,12 +267,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
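
The sketch below illustrates the disk-cache behaviour this patch adds: with --cache_latents_to_disk, latents are written as .npz files next to each image (plus a *_flip.npz file when flip_aug is enabled), an existing file is reused only if its latent spatial size matches the bucket resolution divided by 8, only the main process writes the files, and accelerator.wait_for_everyone() keeps the other processes waiting until the cache exists. It is a minimal standalone sketch, not the patched code: the helper names, the example paths, and the npz key name "latents" are assumptions made for illustration.

# Standalone sketch of the disk-cache round trip (helper names and npz key are assumptions).
import os

import numpy as np
import torch


def npz_paths(image_path):
    # the patch derives cache paths from the image path: foo.png -> foo.npz / foo_flip.npz
    base = os.path.splitext(image_path)[0]
    return base + ".npz", base + "_flip.npz"


def cache_is_valid(npz_path, bucket_reso):
    # bucket_reso is (W, H); a cached latent is reused only if its spatial size is (H//8, W//8)
    expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8)
    if not os.path.exists(npz_path):
        return False
    cached = np.load(npz_path)
    return cached["latents"].shape[1:3] == expected_latents_size  # key name assumed for this sketch


def save_latent(npz_path, latent):
    # the patch stores latents as float32 numpy arrays (latent.float().numpy())
    np.savez(npz_path, latents=latent.float().numpy())


if __name__ == "__main__":
    # fake a 512x768 bucket: a latent has shape (C, H//8, W//8) = (4, 96, 64)
    bucket_reso = (512, 768)
    latent = torch.randn(4, bucket_reso[1] // 8, bucket_reso[0] // 8)

    npz_path, _ = npz_paths("example.png")
    if not cache_is_valid(npz_path, bucket_reso):
        save_latent(npz_path, latent)       # in the patch, only the main process writes
    cached = np.load(npz_path)["latents"]   # training later reads latents back from disk
    print(cached.shape)                     # (4, 96, 64)

With the patch applied, the new option would typically be passed alongside the existing flag, e.g. accelerate launch train_db.py ... --cache_latents --cache_latents_to_disk; later runs can then skip VAE encoding for any image whose cached latents already match its bucket size.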