add vae_batch_size option for faster caching

This commit is contained in:
Kohya S
2023-03-21 18:15:57 +09:00
parent cca3804503
commit 1816ac3271
5 changed files with 32 additions and 19 deletions

View File

@@ -138,7 +138,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@@ -675,10 +675,11 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def cache_latents(self, vae):
# TODO ここを高速化した
def cache_latents(self, vae, vae_batch_size=1):
# ちょっと速くした
print("caching latents.")
for info in tqdm(self.image_data.values()):
infos = []
for info in self.image_data.values():
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None:
@@ -689,18 +690,29 @@ class BaseDataset(torch.utils.data.Dataset):
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue
infos.append(info)
for i in tqdm(range(0, len(infos), vae_batch_size), smoothing=1, total=len(infos) // vae_batch_size):
batch = infos[i : i + vae_batch_size]
images = []
for info in batch:
image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(subset, image, info.bucket_reso, info.resized_size)
image = self.image_transforms(image)
images.append(image)
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents = latent
if subset.flip_aug:
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
img_tensors = torch.flip(img_tensors, dims=[3])
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
info.latents_flipped = latent
def get_image_size(self, image_path):
image = Image.open(image_path)
@@ -1241,10 +1253,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# for dataset in self.datasets:
# dataset.make_buckets()
def cache_latents(self, vae):
def cache_latents(self, vae, vae_batch_size=1):
for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]")
dataset.cache_latents(vae)
dataset.cache_latents(vae, vae_batch_size)
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -1990,6 +2002,7 @@ def add_dataset_arguments(
action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可",
)
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
)

View File

@@ -114,7 +114,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@@ -139,7 +139,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@@ -228,7 +228,7 @@ def train(args):
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae)
train_dataset_group.cache_latents(vae, args.vae_batch_size)
vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()