WIP: update new latents caching

This commit is contained in:
Kohya S
2024-07-09 23:15:38 +09:00
parent 9dc7997803
commit 3d402927ef
3 changed files with 77 additions and 26 deletions

View File

@@ -1,4 +1,5 @@
import argparse
import glob
import math
import os
from typing import List, Optional, Tuple, Union
@@ -282,12 +283,26 @@ def sample_images(*args, **kwargs):
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None: def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.vae = None
def set_vae(self, vae: sd3_models.SDVAE):
self.vae = vae self.vae = vae
def get_latents_npz_path(self, absolute_path: str): def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
if not self.cache_to_disk:
@@ -331,24 +346,24 @@ class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
with torch.no_grad():
latents = self.vae.encode(img_tensor).to("cpu") latents_tensors = self.vae.encode(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = self.vae.encode(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents) flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks):
if self.cache_to_disk:
# save_latents_to_disk(
# info.latents_npz,
# latent,
# info.latents_original_size,
# info.latents_crop_ltrb,
# flipped_latent,
# alpha_mask,
# )
kwargs = {}
if flipped_latent is not None:
kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
@@ -357,12 +372,12 @@ class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
np.savez(
info.latents_npz,
latents=latents.float().cpu().numpy(), latents=latents.float().cpu().numpy(),
original_size=np.array(original_sizes), original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrbs), crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
else:
info.latents = latent info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask

View File

@@ -360,11 +360,23 @@ class AugHelper:
class LatentsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@@ -373,10 +385,15 @@ class LatentsCachingStrategy:
def batch_size(self):
return self._batch_size
def get_latents_npz_path(self, absolute_path: str): def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
raise NotImplementedError raise NotImplementedError
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str:
raise NotImplementedError
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -1034,7 +1051,7 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path) info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
if not is_main_process: # prepare for multi-gpu, only store to info
continue
@@ -1730,6 +1747,18 @@ class DreamBoothDataset(BaseDataset):
img_paths = glob_images(subset.image_dir, "*")
sizes = [None] * len(img_paths)
# new caching: get image size from cache files
strategy = LatentsCachingStrategy.get_strategy()
if strategy is not None:
logger.info("get image size from cache files")
size_set_count = 0
for i, img_path in enumerate(tqdm(img_paths)):
w, h = strategy.get_image_size_from_image_absolute_path(img_path)
if w is not None and h is not None:
sizes[i] = [w, h]
size_set_count += 1
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
if use_cached_info_for_subset:
@@ -2807,12 +2836,12 @@ def cache_batch_text_encoder_outputs_sd3(
b_lg_out = b_lg_out.detach()
b_t5_out = b_t5_out.detach()
b_pool = b_pool.detach()
for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool):
# debug: NaN check
if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any():
raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}")
if cache_to_disk:
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
else: else:

View File

@@ -91,6 +91,15 @@ def train(args):
# load tokenizer
sd3_tokenizer = sd3_models.SD3Tokenizer()
# prepare caching strategy
if args.new_caching:
latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy(
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
)
else:
latents_caching_strategy = None
train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
# データセットを準備する
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
@@ -217,10 +226,8 @@ def train(args):
file_suffix="_sd3.npz",
)
else:
strategy = sd3_train_utils.Sd3LatentsCachingStrategy( latents_caching_strategy.set_vae(vae)
vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy)
)
train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
vae.to("cpu") # if no sampling, vae can be deleted
clean_memory_on_device(accelerator.device)