From 1816ac327174c5750511e54893de736974561147 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 21 Mar 2023 18:15:57 +0900 Subject: [PATCH] add vae_batch_size option for faster caching --- fine_tune.py | 2 +- library/train_util.py | 43 +++++++++++++++++++++++++------------- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- 5 files changed, 32 insertions(+), 19 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index d927bd73..0df6bc62 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -138,7 +138,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae) + train_dataset_group.cache_latents(vae, args.vae_batch_size) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/library/train_util.py b/library/train_util.py index 04a199ac..8dbd4dbb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -675,10 +675,11 @@ class BaseDataset(torch.utils.data.Dataset): def is_latent_cacheable(self): return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) - def cache_latents(self, vae): - # TODO ここを高速化したい + def cache_latents(self, vae, vae_batch_size=1): + # ちょっと速くした print("caching latents.") - for info in tqdm(self.image_data.values()): + infos = [] + for info in self.image_data.values(): subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: @@ -689,18 +690,29 @@ class BaseDataset(torch.utils.data.Dataset): info.latents_flipped = torch.FloatTensor(info.latents_flipped) continue - image = self.load_image(info.absolute_path) - image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) + infos.append(info) - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + for i in tqdm(range(0, len(infos), vae_batch_size), smoothing=1, total=len(infos) // vae_batch_size): + batch = infos[i : i + vae_batch_size] + images = [] + for info in batch: + image = self.load_image(info.absolute_path) + image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size) + image = self.image_transforms(image) + images.append(image) + + img_tensors = torch.stack(images, dim=0) + img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) + + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + for info, latent in zip(batch, latents): + info.latents = latent if subset.flip_aug: - image = image[:, ::-1].copy() # cannot convert to Tensor without copy - img_tensor = self.image_transforms(image) - img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) - info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") + img_tensors = torch.flip(img_tensors, dims=[3]) + latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + for info, latent in zip(batch, latents): + info.latents_flipped = latent def get_image_size(self, image_path): image = Image.open(image_path) @@ -1200,7 +1212,7 @@ class FineTuningDataset(BaseDataset): # if not full path, check image_dir. if image_dir is None, return None if subset.image_dir is None: return None, None - + # image_key is relative path npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz") npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz") @@ -1241,10 +1253,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset): # for dataset in self.datasets: # dataset.make_buckets() - def cache_latents(self, vae): + def cache_latents(self, vae, vae_batch_size=1): for i, dataset in enumerate(self.datasets): print(f"[Dataset {i}]") - dataset.cache_latents(vae) + dataset.cache_latents(vae, vae_batch_size) def is_latent_cacheable(self) -> bool: return all([dataset.is_latent_cacheable() for dataset in self.datasets]) @@ -1990,6 +2002,7 @@ def add_dataset_arguments( action="store_true", help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)", ) + parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ") parser.add_argument( "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" ) diff --git a/train_db.py b/train_db.py index 81aeda19..779e9d00 100644 --- a/train_db.py +++ b/train_db.py @@ -114,7 +114,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae) + train_dataset_group.cache_latents(vae, args.vae_batch_size) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/train_network.py b/train_network.py index 7f910df4..db6dfc2f 100644 --- a/train_network.py +++ b/train_network.py @@ -139,7 +139,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae) + train_dataset_group.cache_latents(vae, args.vae_batch_size) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index e4ab7b5c..57ac6ee6 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -228,7 +228,7 @@ def train(args): vae.requires_grad_(False) vae.eval() with torch.no_grad(): - train_dataset_group.cache_latents(vae) + train_dataset_group.cache_latents(vae, args.vae_batch_size) vae.to("cpu") if torch.cuda.is_available(): torch.cuda.empty_cache()