add vae_batch_size option for faster caching

This commit is contained in:
Kohya S
2023-03-21 18:15:57 +09:00
parent cca3804503
commit 1816ac3271
5 changed files with 32 additions and 19 deletions

View File

@@ -138,7 +138,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@@ -675,10 +675,11 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self): def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def cache_latents(self, vae): def cache_latents(self, vae, vae_batch_size=1):
# TODO ここを高速化した # ちょっと速くした
print("caching latents.") print("caching latents.")
for info in tqdm(self.image_data.values()): infos = []
for info in self.image_data.values():
subset = self.image_to_subset[info.image_key] subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None: if info.latents_npz is not None:
@@ -689,18 +690,29 @@ class BaseDataset(torch.utils.data.Dataset):
info.latents_flipped = torch.FloatTensor(info.latents_flipped) info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue continue
image = self.load_image(info.absolute_path) infos.append(info)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
img_tensor = self.image_transforms(image) for i in tqdm(range(0, len(infos), vae_batch_size), smoothing=1, total=len(infos) // vae_batch_size):
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) batch = infos[i : i + vae_batch_size]
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") images = []
for info in batch:
image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
image = self.image_transforms(image)
images.append(image)
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents = latent
if subset.flip_aug: if subset.flip_aug:
image = image[:, ::-1].copy() # cannot convert to Tensor without copy img_tensors = torch.flip(img_tensors, dims=[3])
img_tensor = self.image_transforms(image) latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype) for info, latent in zip(batch, latents):
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu") info.latents_flipped = latent
def get_image_size(self, image_path): def get_image_size(self, image_path):
image = Image.open(image_path) image = Image.open(image_path)
@@ -1241,10 +1253,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# for dataset in self.datasets: # for dataset in self.datasets:
# dataset.make_buckets() # dataset.make_buckets()
def cache_latents(self, vae): def cache_latents(self, vae, vae_batch_size=1):
for i, dataset in enumerate(self.datasets): for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
dataset.cache_latents(vae) dataset.cache_latents(vae, vae_batch_size)
def is_latent_cacheable(self) -> bool: def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets]) return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -1990,6 +2002,7 @@ def add_dataset_arguments(
action="store_true", action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可", help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可",
) )
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument( parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
) )

View File

@@ -114,7 +114,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@@ -139,7 +139,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@@ -228,7 +228,7 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae) train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()