mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
WIP: new latents caching
This commit is contained in:
@@ -359,6 +359,30 @@ class AugHelper:
|
||||
return self.color_aug if use_color_aug else None
|
||||
|
||||
|
||||
class LatentsCachingStrategy:
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str):
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseSubset:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -986,6 +1010,69 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
]
|
||||
)
|
||||
|
||||
def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
|
||||
r"""
|
||||
a brand new method to cache latents. This method caches latents with caching strategy.
|
||||
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
|
||||
"""
|
||||
logger.info("caching latents with caching strategy.")
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path)
|
||||
if not is_main_process: # prepare for multi-gpu, only store to info
|
||||
continue
|
||||
|
||||
cache_available = caching_strategy.is_disk_cached_latents_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
batch.append(info)
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= caching_strategy.batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
# if cache to disk, don't cache latents in non-main process, set to info only
|
||||
if caching_strategy.cache_to_disk and not is_main_process:
|
||||
return
|
||||
|
||||
if len(batches) == 0:
|
||||
logger.info("no latents to cache")
|
||||
return
|
||||
|
||||
# iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
|
||||
logger.info("caching latents...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching latents.")
|
||||
@@ -1086,7 +1173,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
logger.info("checking cache existence...")
|
||||
@@ -2207,6 +2294,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
|
||||
|
||||
def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.new_cache_latents(is_main_process, strategy)
|
||||
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||
):
|
||||
@@ -2550,6 +2642,51 @@ def trim_and_resize_if_required(
|
||||
return image, original_size, crop_ltrb
|
||||
|
||||
|
||||
# for new_cache_latents
|
||||
def load_images_and_masks_for_caching(
|
||||
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
|
||||
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
|
||||
r"""
|
||||
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
|
||||
|
||||
returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs
|
||||
|
||||
image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1]
|
||||
alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1]
|
||||
original_sizes: List[Tuple[int, int]] = [(W, H), ...]
|
||||
crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...]
|
||||
"""
|
||||
images: List[torch.Tensor] = []
|
||||
alpha_masks: List[np.ndarray] = []
|
||||
original_sizes: List[Tuple[int, int]] = []
|
||||
crop_ltrbs: List[Tuple[int, int, int, int]] = []
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
|
||||
original_sizes.append(original_size)
|
||||
crop_ltrbs.append(crop_ltrb)
|
||||
|
||||
if use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0
|
||||
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
|
||||
else:
|
||||
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
|
||||
else:
|
||||
alpha_mask = None
|
||||
alpha_masks.append(alpha_mask)
|
||||
|
||||
image = image[:, :, :3] # remove alpha channel if exists
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
|
||||
img_tensor = torch.stack(images, dim=0)
|
||||
return img_tensor, alpha_masks, original_sizes, crop_ltrbs
|
||||
|
||||
|
||||
def cache_batch_latents(
|
||||
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
|
||||
) -> None:
|
||||
@@ -2661,7 +2798,7 @@ def cache_batch_text_encoder_outputs_sd3(
|
||||
):
|
||||
# make input_ids for each text encoder
|
||||
l_tokens, g_tokens, t5_tokens = input_ids
|
||||
|
||||
|
||||
clip_l, clip_g, t5xxl = text_encoders
|
||||
with torch.no_grad():
|
||||
b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens(
|
||||
@@ -2670,8 +2807,12 @@ def cache_batch_text_encoder_outputs_sd3(
|
||||
b_lg_out = b_lg_out.detach()
|
||||
b_t5_out = b_t5_out.detach()
|
||||
b_pool = b_pool.detach()
|
||||
|
||||
|
||||
for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool):
|
||||
# debug: NaN check
|
||||
if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any():
|
||||
raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}")
|
||||
|
||||
if cache_to_disk:
|
||||
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user