Cache latents to disk in the DreamBooth training method

This commit is contained in:
Kohya S
2023-04-12 23:10:39 +09:00
parent 5050971ac6
commit 2e9f7b5f91
6 changed files with 67 additions and 15 deletions

View File

@@ -142,12 +142,14 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする # 学習を準備する:モデルを適切な状態にする
training_models = [] training_models = []
if args.gradient_checkpointing: if args.gradient_checkpointing:

View File

@@ -722,7 +722,7 @@ class BaseDataset(torch.utils.data.Dataset):
def is_latent_cacheable(self): def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
def cache_latents(self, vae, vae_batch_size=1): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# ちょっと速くした # ちょっと速くした
print("caching latents.") print("caching latents.")
@@ -740,11 +740,38 @@ class BaseDataset(torch.utils.data.Dataset):
if info.latents_npz is not None: if info.latents_npz is not None:
info.latents = self.load_latents_from_npz(info, False) info.latents = self.load_latents_from_npz(info, False)
info.latents = torch.FloatTensor(info.latents) info.latents = torch.FloatTensor(info.latents)
info.latents_flipped = self.load_latents_from_npz(info, True) # might be None
# might be None, but that's ok because check is done in dataset
info.latents_flipped = self.load_latents_from_npz(info, True)
if info.latents_flipped is not None: if info.latents_flipped is not None:
info.latents_flipped = torch.FloatTensor(info.latents_flipped) info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue continue
# check disk cache exists and size of latents
if cache_to_disk:
# TODO: refactor to unify with FineTuningDataset
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
if not is_main_process:
continue
cache_available = False
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意
if os.path.exists(info.latents_npz):
cached_latents = np.load(info.latents_npz)
if cached_latents["latents"].shape[1:3] == expected_latents_size:
cache_available = True
if subset.flip_aug:
cache_available = False
if os.path.exists(info.latents_npz_flipped):
cached_latents_flipped = np.load(info.latents_npz_flipped)
if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size:
cache_available = True
if cache_available:
continue
# if last member of batch has different resolution, flush the batch # if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch) batches.append(batch)
@@ -760,6 +787,9 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) > 0: if len(batch) > 0:
batches.append(batch) batches.append(batch)
if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only
return
# iterate batches # iterate batches
for batch in tqdm(batches, smoothing=1, total=len(batches)): for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = [] images = []
@@ -773,13 +803,20 @@ class BaseDataset(torch.utils.data.Dataset):
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents): for info, latent in zip(batch, latents):
if cache_to_disk:
np.savez(info.latents_npz, latent.float().numpy())
else:
info.latents = latent info.latents = latent
if subset.flip_aug: if subset.flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3]) img_tensors = torch.flip(img_tensors, dims=[3])
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents): for info, latent in zip(batch, latents):
if cache_to_disk:
np.savez(info.latents_npz_flipped, latent.float().numpy())
else:
info.latents_flipped = latent info.latents_flipped = latent
def get_image_size(self, image_path): def get_image_size(self, image_path):
@@ -873,10 +910,10 @@ class BaseDataset(torch.utils.data.Dataset):
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0) loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
# image/latentsを処理する # image/latentsを処理する
if image_info.latents is not None: if image_info.latents is not None: # cache_latents=Trueの場合
latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
image = None image = None
elif image_info.latents_npz is not None: elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5) latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
latents = torch.FloatTensor(latents) latents = torch.FloatTensor(latents)
image = None image = None
@@ -1340,10 +1377,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
for dataset in self.datasets: for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs) dataset.enable_XTI(*args, **kwargs)
def cache_latents(self, vae, vae_batch_size=1): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
for i, dataset in enumerate(self.datasets): for i, dataset in enumerate(self.datasets):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size) dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def is_latent_cacheable(self) -> bool: def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets]) return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -2144,9 +2181,14 @@ def add_dataset_arguments(
parser.add_argument( parser.add_argument(
"--cache_latents", "--cache_latents",
action="store_true", action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheするaugmentationは使用不可", help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheするaugmentationは使用不可 ",
) )
parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ") parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
parser.add_argument(
"--cache_latents_to_disk",
action="store_true",
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheするaugmentationは使用不可",
)
parser.add_argument( parser.add_argument(
"--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする" "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
) )

View File

@@ -117,12 +117,14 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
accelerator.wait_for_everyone()
# 学習を準備する:モデルを適切な状態にする # 学習を準備する:モデルを適切な状態にする
train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0 train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加 unet.requires_grad_(True) # 念のため追加

View File

@@ -172,12 +172,14 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
accelerator.wait_for_everyone()
# prepare network # prepare network
import sys import sys

View File

@@ -233,12 +233,14 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
accelerator.wait_for_everyone()
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable() text_encoder.gradient_checkpointing_enable()

View File

@@ -267,12 +267,14 @@ def train(args):
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
with torch.no_grad(): with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size) train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu") vae.to("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
accelerator.wait_for_everyone()
if args.gradient_checkpointing: if args.gradient_checkpointing:
unet.enable_gradient_checkpointing() unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable() text_encoder.gradient_checkpointing_enable()