WIP: new latents caching

This commit is contained in:
Kohya S
2024-07-08 19:48:28 +09:00
parent 50e3d62474
commit c9de7c4e9a
3 changed files with 270 additions and 8 deletions

View File

@@ -1,7 +1,7 @@
import argparse import argparse
import math import math
import os import os
from typing import Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
from safetensors.torch import save_file from safetensors.torch import save_file
@@ -283,6 +283,98 @@ def sample_images(*args, **kwargs):
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs) return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy):
    """Latents caching strategy that encodes images with the SD3 VAE.

    Latents are cached one image at a time, either to an ``*_sd3.npz`` file
    next to the image (``cache_to_disk=True``) or onto the image info object.
    """

    SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"

    def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
        self.vae = vae

    def get_latents_npz_path(self, absolute_path: str):
        """Return the cache file path: the image path with its extension replaced by the SD3 suffix."""
        return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        """Return True when ``npz_path`` holds a cache usable for the given settings.

        Returns False when disk caching is disabled, the file does not exist,
        or its contents do not match ``bucket_reso`` / ``flip_aug`` /
        ``alpha_mask``. When ``skip_disk_cache_validity_check`` is set, an
        existing file is accepted without being opened.

        Raises:
            Exception: any error raised while reading the npz file is logged
                and re-raised.
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        # bucket_reso is (W, H); the check expects latents at 1/8 of that size
        expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8)

        try:
            # NpzFile keeps the archive open until closed, so scope it with `with`
            with np.load(npz_path) as npz:
                # assumes per-image latents are stored as (C, H, W) -- TODO confirm against the VAE output
                if npz["latents"].shape[1:3] != expected_latents_size:
                    return False

                if flip_aug:
                    if "latents_flipped" not in npz:
                        return False
                    if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                        return False

                if alpha_mask:
                    if "alpha_mask" not in npz:
                        return False
                    if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
                        return False
                else:
                    # a cache written with an alpha mask is not reused when masks are disabled
                    if "alpha_mask" in npz:
                        return False
        except Exception:
            logger.error(f"Error loading file: {npz_path}")
            raise

        return True

    def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Encode one batch of images and cache the latents of each image separately.

        Args:
            image_infos: images of the batch; they are stacked into one tensor,
                so they must share the same bucket resolution.
            flip_aug: also encode and cache horizontally flipped latents.
            alpha_mask: load and cache alpha masks.
            random_crop: forwarded to the image loader.
        """
        img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
            image_infos, alpha_mask, random_crop
        )
        img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)

        with torch.no_grad():
            latents = self.vae.encode(img_tensor).to("cpu")
        if flip_aug:
            img_tensor = torch.flip(img_tensor, dims=[3])
            with torch.no_grad():
                flipped_latents = self.vae.encode(img_tensor).to("cpu")
        else:
            flipped_latents = [None] * len(latents)

        # Iterate per image: every npz must hold that image's own latent, size and crop,
        # not the whole batch (the validity check above reads shape[1:3] of a single latent).
        for info, latent, flipped_latent, mask, original_size, crop_ltrb in zip(
            image_infos, latents, flipped_latents, alpha_masks, original_sizes, crop_ltrbs
        ):
            if self.cache_to_disk:
                optional_arrays = {}
                if flipped_latent is not None:
                    optional_arrays["latents_flipped"] = flipped_latent.float().cpu().numpy()
                if mask is not None:
                    optional_arrays["alpha_mask"] = mask.float().cpu().numpy()
                np.savez(
                    info.latents_npz,
                    latents=latent.float().cpu().numpy(),
                    original_size=np.array(original_size),
                    crop_ltrb=np.array(crop_ltrb),
                    **optional_arrays,
                )
            else:
                # NOTE(review): original_size / crop_ltrb are not stored on info in this
                # branch -- confirm whether in-memory consumers need them.
                info.latents = latent
                if flip_aug:
                    info.latents_flipped = flipped_latent
                info.alpha_mask = mask

        if not train_util.HIGH_VRAM:
            clean_memory_on_device(self.vae.device)
# region Diffusers # region Diffusers

View File

@@ -359,6 +359,30 @@ class AugHelper:
return self.color_aug if use_color_aug else None return self.color_aug if use_color_aug else None
class LatentsCachingStrategy:
    """Base class for latents caching strategies.

    Stores the caching configuration and declares the three hooks a concrete
    strategy has to provide; every hook raises ``NotImplementedError`` here.
    """

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
        self._batch_size = batch_size
        self._cache_to_disk = cache_to_disk

    @property
    def batch_size(self) -> int:
        """Batch size given at construction (read-only)."""
        return self._batch_size

    @property
    def cache_to_disk(self) -> bool:
        """Whether latents are cached to disk, as given at construction (read-only)."""
        return self._cache_to_disk

    def get_latents_npz_path(self, absolute_path: str):
        """Hook: map an image path to the path of its latents cache file."""
        raise NotImplementedError()

    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
        """Hook: tell whether the cache file at ``npz_path`` can be reused for the given settings."""
        raise NotImplementedError()

    def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Hook: compute and store the latents for one batch of images."""
        raise NotImplementedError()
class BaseSubset: class BaseSubset:
def __init__( def __init__(
self, self,
@@ -986,6 +1010,69 @@ class BaseDataset(torch.utils.data.Dataset):
] ]
) )
def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
    r"""
    Cache latents through a caching strategy.

    The normal cache_latents method is used by default; this method is used when a caching strategy is specified.

    Images that still need caching are grouped into batches. All images of a batch share the same
    bucket resolution and the same subset settings (flip_aug, alpha_mask, random_crop), and each
    batch is cached with the settings of the subset its own images belong to.

    Args:
        is_main_process: when caching to disk, only the main process validates and writes caches;
            other processes only record the cache path on each image info.
        caching_strategy: strategy providing the cache path, the validity check and the encoder.
    """
    logger.info("caching latents with caching strategy.")

    image_infos = list(self.image_data.values())

    # sort by resolution
    image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])

    # split by resolution and subset settings; each entry is (settings, [image infos])
    batches = []
    batch = []
    batch_settings = None  # (flip_aug, alpha_mask, random_crop) of the batch being filled

    logger.info("checking cache validity...")
    for info in tqdm(image_infos):
        subset = self.image_to_subset[info.image_key]

        if info.latents_npz is not None:  # fine tuning dataset
            continue

        # check disk cache exists and size of latents
        if caching_strategy.cache_to_disk:
            info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path)
            if not is_main_process:  # prepare for multi-gpu, only store to info
                continue

            cache_available = caching_strategy.is_disk_cached_latents_expected(
                info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
            )
            if cache_available:  # do not add to batch
                continue

        settings = (subset.flip_aug, subset.alpha_mask, subset.random_crop)

        # flush when the resolution or the subset settings change, so that a batch is never
        # encoded with the settings of a subset it does not belong to
        if batch and (batch[-1].bucket_reso != info.bucket_reso or batch_settings != settings):
            batches.append((batch_settings, batch))
            batch = []

        batch.append(info)
        batch_settings = settings

        # if number of data in batch is enough, flush the batch
        if len(batch) >= caching_strategy.batch_size:
            batches.append((batch_settings, batch))
            batch = []

    if batch:
        batches.append((batch_settings, batch))

    # if cache to disk, don't cache latents in non-main process, set to info only
    if caching_strategy.cache_to_disk and not is_main_process:
        return

    if not batches:
        logger.info("no latents to cache")
        return

    # iterate batches: batch doesn't have image here. image will be loaded in cache_batch_latents and discarded
    logger.info("caching latents...")
    for (flip_aug, use_alpha_mask, random_crop), batch in tqdm(batches, smoothing=1, total=len(batches)):
        caching_strategy.cache_batch_latents(batch, flip_aug, use_alpha_mask, random_crop)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと # マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching latents.") logger.info("caching latents.")
@@ -2207,6 +2294,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
logger.info(f"[Dataset {i}]") logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix) dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
    """Run strategy-based latents caching on every dataset of the group, in order.

    A ``[Dataset i]`` header is logged before each dataset is processed.
    """
    for index, member in enumerate(self.datasets):
        header = f"[Dataset {index}]"
        logger.info(header)
        member.new_cache_latents(is_main_process, strategy)
def cache_text_encoder_outputs( def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
): ):
@@ -2550,6 +2642,51 @@ def trim_and_resize_if_required(
return image, original_size, crop_ltrb return image, original_size, crop_ltrb
# for new_cache_latents
def load_images_and_masks_for_caching(
    image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
) -> Tuple[torch.Tensor, List[torch.Tensor], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
    r"""
    Load, resize/crop and normalize the images of one batch, optionally with their alpha masks.

    requires image_infos to have: [absolute_path or image], bucket_reso, resized_size

    returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs

    image_tensor: torch.Tensor of shape [B, 3, H, W], normalized to [-1, 1]
    alpha_masks: one entry per image; a float32 torch.Tensor of shape [H, W] normalized to [0, 1],
        or None for every image when use_alpha_mask is False
    original_sizes: List[Tuple[int, int]] = [(W, H), ...]
    crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...]
    """
    images: List[torch.Tensor] = []
    alpha_masks: List[torch.Tensor] = []
    original_sizes: List[Tuple[int, int]] = []
    crop_ltrbs: List[Tuple[int, int, int, int]] = []
    for info in image_infos:
        image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
        # TODO: image metadata can be corrupted, so the bucket assigned from the metadata may not match
        # the actual image size; a check needs to be added
        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)

        original_sizes.append(original_size)
        crop_ltrbs.append(crop_ltrb)

        if use_alpha_mask:
            if image.shape[2] == 4:
                alpha_mask = image[:, :, 3]  # [H,W]
                alpha_mask = alpha_mask.astype(np.float32) / 255.0
                alpha_mask = torch.FloatTensor(alpha_mask)  # [H,W]
            else:
                # no alpha channel: fully opaque mask. `image` is a NumPy array here, which
                # torch.ones_like does not accept, so build the mask from its shape instead.
                alpha_mask = torch.ones(image.shape[:2], dtype=torch.float32)  # [H,W]
        else:
            alpha_mask = None
        alpha_masks.append(alpha_mask)

        image = image[:, :, :3]  # remove alpha channel if exists
        image = IMAGE_TRANSFORMS(image)
        images.append(image)

    img_tensor = torch.stack(images, dim=0)
    return img_tensor, alpha_masks, original_sizes, crop_ltrbs
def cache_batch_latents( def cache_batch_latents(
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
) -> None: ) -> None:
@@ -2672,6 +2809,10 @@ def cache_batch_text_encoder_outputs_sd3(
b_pool = b_pool.detach() b_pool = b_pool.detach()
for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool): for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool):
# debug: NaN check
if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any():
raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}")
if cache_to_disk: if cache_to_disk:
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool) save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
else: else:

View File

@@ -204,11 +204,22 @@ def train(args):
vae.to(accelerator.device, dtype=vae_dtype) vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False) vae.requires_grad_(False)
vae.eval() vae.eval()
vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible
with torch.no_grad(): if not args.new_caching:
train_dataset_group.cache_latents( vae_wrapper = sd3_models.VAEWrapper(vae) # make SD/SDXL compatible
vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz" with torch.no_grad():
train_dataset_group.cache_latents(
vae_wrapper,
args.vae_batch_size,
args.cache_latents_to_disk,
accelerator.is_main_process,
file_suffix="_sd3.npz",
)
else:
strategy = sd3_train_utils.Sd3LatensCachingStrategy(
vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
) )
train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
vae.to("cpu") vae.to("cpu")
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
@@ -699,6 +710,17 @@ def train(args):
sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype) sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# debug: NaN check for all inputs
if torch.any(torch.isnan(noisy_model_input)):
accelerator.print("NaN found in noisy_model_input, replacing with zeros")
noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
if torch.any(torch.isnan(context)):
accelerator.print("NaN found in context, replacing with zeros")
context = torch.nan_to_num(context, 0, out=context)
if torch.any(torch.isnan(pool)):
accelerator.print("NaN found in pool, replacing with zeros")
pool = torch.nan_to_num(pool, 0, out=pool)
# call model # call model
with accelerator.autocast(): with accelerator.autocast():
model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool) model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool)
@@ -908,6 +930,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None, default=None,
help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
) )
parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う")
parser.add_argument(
"--skip_latents_validity_check",
action="store_true",
help="skip latents validity check / latentsの正当性チェックをスキップする",
)
return parser return parser