WIP: new latents caching
(commit in kohya-ss/sd-scripts, mirror of https://github.com/kohya-ss/sd-scripts.git)
sd3_train_utils.py

@@ -1,7 +1,7 @@
 import argparse
 import math
 import os
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from safetensors.torch import save_file
@@ -283,6 +283,98 @@ def sample_images(*args, **kwargs):
     return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
 
 
+class Sd3LatensCachingStrategy(train_util.LatentsCachingStrategy):
+    SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
+
+    def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
+        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
+        self.vae = vae
+
+    def get_latents_npz_path(self, absolute_path: str):
+        return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX
+
+    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
+        if not self.cache_to_disk:
+            return False
+        if not os.path.exists(npz_path):
+            return False
+        if self.skip_disk_cache_validity_check:
+            return True
+
+        expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8)  # bucket_reso is (W, H)
+
+        try:
+            npz = np.load(npz_path)
+            if npz["latents"].shape[1:3] != expected_latents_size:
+                return False
+
+            if flip_aug:
+                if "latents_flipped" not in npz:
+                    return False
+                if npz["latents_flipped"].shape[1:3] != expected_latents_size:
+                    return False
+
+            if alpha_mask:
+                if "alpha_mask" not in npz:
+                    return False
+                if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
+                    return False
+            else:
+                if "alpha_mask" in npz:
+                    return False
+        except Exception as e:
+            logger.error(f"Error loading file: {npz_path}")
+            raise e
+
+        return True
+
+    def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
+        img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
+            image_infos, alpha_mask, random_crop
+        )
+        img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
+
+        with torch.no_grad():
+            latents = self.vae.encode(img_tensor).to("cpu")
+        if flip_aug:
+            img_tensor = torch.flip(img_tensor, dims=[3])
+            with torch.no_grad():
+                flipped_latents = self.vae.encode(img_tensor).to("cpu")
+        else:
+            flipped_latents = [None] * len(latents)
+
+        for i, (info, latent, flipped_latent, mask) in enumerate(zip(image_infos, latents, flipped_latents, alpha_masks)):
+            if self.cache_to_disk:
+                # save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask)
+                kwargs = {}
+                if flipped_latent is not None:
+                    kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
+                if mask is not None:
+                    kwargs["alpha_mask"] = mask.float().cpu().numpy()
+                np.savez(
+                    info.latents_npz,
+                    latents=latent.float().cpu().numpy(),  # per-image latent
+                    original_size=np.array(original_sizes[i]),
+                    crop_ltrb=np.array(crop_ltrbs[i]),
+                    **kwargs,
+                )
+            else:
+                info.latents = latent
+                if flip_aug:
+                    info.latents_flipped = flipped_latent
+                info.alpha_mask = mask
+
+        if not train_util.HIGH_VRAM:
+            clean_memory_on_device(self.vae.device)
+
+
 # region Diffusers
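For reference, the cache files written above are plain npz archives keyed exactly as in np.savez; a short inspection sketch (the file path is hypothetical, the keys come from the diff):

    import numpy as np

    npz = np.load("image_sd3.npz")     # hypothetical path produced by get_latents_npz_path
    print(npz["latents"].shape)        # per-image latent; spatial dims = bucket size // 8
    print(npz["original_size"])        # (W, H) of the source image
    print(npz["crop_ltrb"])            # (left, top, right, bottom) crop info
    if "latents_flipped" in npz:       # present only when flip_aug was enabled
        print(npz["latents_flipped"].shape)
    if "alpha_mask" in npz:            # present only when alpha_mask was enabled
        print(npz["alpha_mask"].shape) # (H, W), values in [0, 1]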
train_util.py

@@ -359,6 +359,30 @@ class AugHelper:
         return self.color_aug if use_color_aug else None
 
 
+class LatentsCachingStrategy:
+    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
+        self._cache_to_disk = cache_to_disk
+        self._batch_size = batch_size
+        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
+
+    @property
+    def cache_to_disk(self):
+        return self._cache_to_disk
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    def get_latents_npz_path(self, absolute_path: str):
+        raise NotImplementedError
+
+    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
+        raise NotImplementedError
+
+    def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
+        raise NotImplementedError
+
+
 class BaseSubset:
     def __init__(
         self,
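The base class is a strategy interface: path naming, cache validation, and batch encoding are the three overridable points, while cache_to_disk and batch_size are shared state. A minimal sketch of a subclass, with hypothetical names (MyLatentsCachingStrategy, "_my.npz"), assuming it lives where LatentsCachingStrategy is importable:

    import os

    class MyLatentsCachingStrategy(LatentsCachingStrategy):
        SUFFIX = "_my.npz"  # hypothetical cache-file suffix

        def get_latents_npz_path(self, absolute_path: str):
            return os.path.splitext(absolute_path)[0] + self.SUFFIX

        def is_disk_cached_latents_expected(self, bucket_reso, npz_path, flip_aug, alpha_mask):
            # simplest possible policy: trust any existing file
            # (a real strategy also validates latent shapes, as the SD3 one does)
            return self.cache_to_disk and os.path.exists(npz_path)

        def cache_batch_latents(self, batch, flip_aug, alpha_mask, random_crop):
            raise NotImplementedError  # encode with the model-specific VAE, then save or attach latents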
@@ -986,6 +1010,69 @@ class BaseDataset(torch.utils.data.Dataset):
             ]
         )
 
+    def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
+        r"""
+        A new method for caching latents: caches them according to the given caching strategy.
+        The conventional cache_latents method is used by default; this one is used when a caching strategy is specified.
+        """
+        logger.info("caching latents with caching strategy.")
+        image_infos = list(self.image_data.values())
+
+        # sort by resolution
+        image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
+
+        # split by resolution
+        batches = []
+        batch = []
+        logger.info("checking cache validity...")
+        for info in tqdm(image_infos):
+            subset = self.image_to_subset[info.image_key]
+
+            if info.latents_npz is not None:  # fine tuning dataset
+                continue
+
+            # check whether the disk cache exists and the latents have the expected size
+            if caching_strategy.cache_to_disk:
+                # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
+                info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path)
+                if not is_main_process:  # prepare for multi-gpu: only store the path to info
+                    continue
+
+                cache_available = caching_strategy.is_disk_cached_latents_expected(
+                    info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
+                )
+                if cache_available:  # do not add to batch
+                    continue
+
+            # if the last member of the batch has a different resolution, flush the batch
+            if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
+                batches.append(batch)
+                batch = []
+
+            batch.append(info)
+
+            # if the batch is full, flush it
+            if len(batch) >= caching_strategy.batch_size:
+                batches.append(batch)
+                batch = []
+
+        if len(batch) > 0:
+            batches.append(batch)
+
+        # when caching to disk, don't cache latents in non-main processes; the paths were stored to info above
+        if caching_strategy.cache_to_disk and not is_main_process:
+            return
+
+        if len(batches) == 0:
+            logger.info("no latents to cache")
+            return
+
+        # iterate over batches: a batch holds no image data here; images are loaded in cache_batch_latents and discarded
+        logger.info("caching latents...")
+        for batch in tqdm(batches, smoothing=1, total=len(batches)):
+            # cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+            # NOTE: `subset` is the subset of the last processed image; this assumes flip_aug/alpha_mask/random_crop
+            # are uniform across the dataset's subsets
+            caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
+
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
         # multi-GPU is not supported here; use tools/cache_latents.py for that
         logger.info("caching latents.")
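The batching above flushes on either a resolution change or a full batch, so every batch is homogeneous in bucket resolution. The same grouping logic as a self-contained sketch (function and variable names hypothetical):

    from typing import List, Tuple

    def group_by_reso(items: List[Tuple[Tuple[int, int], str]], batch_size: int) -> List[List[str]]:
        """Group (bucket_reso, key) pairs into batches that never mix resolutions."""
        items = sorted(items, key=lambda x: x[0][0] * x[0][1])  # sort by pixel count, as above
        batches, batch, last_reso = [], [], None
        for reso, key in items:
            # flush on resolution change or when the batch is already full
            if batch and (last_reso != reso or len(batch) >= batch_size):
                batches.append(batch)
                batch = []
            batch.append(key)
            last_reso = reso
        if batch:
            batches.append(batch)
        return batches

    # three 512x512 images and one 640x384 image, batch_size=2:
    print(group_by_reso([((512, 512), "a"), ((512, 512), "b"), ((512, 512), "c"), ((640, 384), "d")], 2))
    # -> [['d'], ['a', 'b'], ['c']]  (640x384 sorts first: fewer pixels)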
@@ -1086,7 +1173,7 @@ class BaseDataset(torch.utils.data.Dataset):
 
         if batch_size is None:
             batch_size = self.batch_size
 
         image_infos = list(self.image_data.values())
 
         logger.info("checking cache existence...")
@@ -2207,6 +2294,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             logger.info(f"[Dataset {i}]")
             dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
 
+    def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
+        for i, dataset in enumerate(self.datasets):
+            logger.info(f"[Dataset {i}]")
+            dataset.new_cache_latents(is_main_process, strategy)
+
     def cache_text_encoder_outputs(
         self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
     ):
@@ -2550,6 +2642,51 @@ def trim_and_resize_if_required(
     return image, original_size, crop_ltrb
 
 
+# for new_cache_latents
+def load_images_and_masks_for_caching(
+    image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
+) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
+    r"""
+    requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
+
+    returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs
+
+    image_tensor: torch.Tensor of shape [B, 3, H, W], normalized to [-1, 1]
+    alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1]
+    original_sizes: List[Tuple[int, int]] = [(W, H), ...]
+    crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...]
+    """
+    images: List[torch.Tensor] = []
+    alpha_masks: List[np.ndarray] = []
+    original_sizes: List[Tuple[int, int]] = []
+    crop_ltrbs: List[Tuple[int, int, int, int]] = []
+    for info in image_infos:
+        image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
+        # TODO: image metadata can be corrupted, so the bucket assigned from metadata may not match
+        # the actual image size; a check should be added
+        image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
+
+        original_sizes.append(original_size)
+        crop_ltrbs.append(crop_ltrb)
+
+        if use_alpha_mask:
+            if image.shape[2] == 4:
+                alpha_mask = image[:, :, 3]  # [H,W]
+                alpha_mask = alpha_mask.astype(np.float32) / 255.0
+                alpha_mask = torch.FloatTensor(alpha_mask)  # [H,W]
+            else:
+                # image is a numpy array here, so build the all-ones mask from its shape
+                # (torch.ones_like requires a tensor input)
+                alpha_mask = torch.ones((image.shape[0], image.shape[1]), dtype=torch.float32)  # [H,W]
+        else:
+            alpha_mask = None
+        alpha_masks.append(alpha_mask)
+
+        image = image[:, :, :3]  # remove alpha channel if it exists
+        image = IMAGE_TRANSFORMS(image)
+        images.append(image)
+
+    img_tensor = torch.stack(images, dim=0)
+    return img_tensor, alpha_masks, original_sizes, crop_ltrbs
+
+
 def cache_batch_latents(
     vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
 ) -> None:
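Per the docstring, IMAGE_TRANSFORMS maps uint8 HWC images to float CHW tensors in [-1, 1]; a minimal equivalent sketch, assuming the transform is ToTensor followed by Normalize(0.5, 0.5) (an assumption about the codebase, not part of this diff):

    import numpy as np
    import torch

    def to_model_input(image_hwc_uint8: np.ndarray) -> torch.Tensor:
        # uint8 [H, W, 3] -> float32 [3, H, W], scaled to [0, 1] (ToTensor) ...
        t = torch.from_numpy(image_hwc_uint8).permute(2, 0, 1).float() / 255.0
        # ... then shifted to [-1, 1] (Normalize(mean=0.5, std=0.5))
        return t * 2.0 - 1.0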
@@ -2661,7 +2798,7 @@ def cache_batch_text_encoder_outputs_sd3(
 ):
     # make input_ids for each text encoder
     l_tokens, g_tokens, t5_tokens = input_ids
 
     clip_l, clip_g, t5xxl = text_encoders
     with torch.no_grad():
         b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens(
@@ -2670,8 +2807,12 @@ def cache_batch_text_encoder_outputs_sd3(
         b_lg_out = b_lg_out.detach()
         b_t5_out = b_t5_out.detach()
         b_pool = b_pool.detach()
 
     for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool):
+        # debug: NaN check
+        if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any():
+            raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}")
+
         if cache_to_disk:
             save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
         else:
sd3_train.py (37 lines changed)
@@ -204,11 +204,22 @@ def train(args):
         vae.to(accelerator.device, dtype=vae_dtype)
         vae.requires_grad_(False)
         vae.eval()
-        vae_wrapper = sd3_models.VAEWrapper(vae)  # make SD/SDXL compatible
-        with torch.no_grad():
-            train_dataset_group.cache_latents(
-                vae_wrapper, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process, file_suffix="_sd3.npz"
-            )
+
+        if not args.new_caching:
+            vae_wrapper = sd3_models.VAEWrapper(vae)  # make SD/SDXL compatible
+            with torch.no_grad():
+                train_dataset_group.cache_latents(
+                    vae_wrapper,
+                    args.vae_batch_size,
+                    args.cache_latents_to_disk,
+                    accelerator.is_main_process,
+                    file_suffix="_sd3.npz",
+                )
+        else:
+            strategy = sd3_train_utils.Sd3LatensCachingStrategy(
+                vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
+            )
+            train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
         vae.to("cpu")
         clean_memory_on_device(accelerator.device)
@@ -699,6 +710,17 @@ def train(args):
                 sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
                 noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
 
+                # debug: NaN check for all inputs
+                if torch.any(torch.isnan(noisy_model_input)):
+                    accelerator.print("NaN found in noisy_model_input, replacing with zeros")
+                    noisy_model_input = torch.nan_to_num(noisy_model_input, 0, out=noisy_model_input)
+                if torch.any(torch.isnan(context)):
+                    accelerator.print("NaN found in context, replacing with zeros")
+                    context = torch.nan_to_num(context, 0, out=context)
+                if torch.any(torch.isnan(pool)):
+                    accelerator.print("NaN found in pool, replacing with zeros")
+                    pool = torch.nan_to_num(pool, 0, out=pool)
+
                 # call model
                 with accelerator.autocast():
                     model_pred = mmdit(noisy_model_input, timesteps, context=context, y=pool)
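Since out= points at the input tensor itself, each nan_to_num call above rewrites the tensor in place, and the reassignment just re-binds the same storage; a minimal sketch of that behavior:

    import torch

    x = torch.tensor([1.0, float("nan"), 3.0])
    y = torch.nan_to_num(x, 0.0, out=x)  # writes into x and returns it
    assert y is x                        # same storage, not a copy
    assert x[1].item() == 0.0            # the NaN was replaced with zero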
@@ -908,6 +930,13 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数",
     )
+
+    parser.add_argument("--new_caching", action="store_true", help="use new caching method / 新しいキャッシング方法を使う")
+    parser.add_argument(
+        "--skip_latents_validity_check",
+        action="store_true",
+        help="skip latents validity check / latentsの正当性チェックをスキップする",
+    )
     return parser
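Both switches are plain store_true flags, so the legacy caching path remains the default; a self-contained sketch of how they behave (the parser reduced to just the two new arguments):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--new_caching", action="store_true", help="use new caching method")
    parser.add_argument("--skip_latents_validity_check", action="store_true", help="skip latents validity check")

    args = parser.parse_args([])  # no flags given: both default to False, i.e. legacy caching
    assert not args.new_caching and not args.skip_latents_validity_check

    args = parser.parse_args(["--new_caching", "--skip_latents_validity_check"])
    assert args.new_caching and args.skip_latents_validity_check

In a real run this corresponds to invoking sd3_train.py with "--new_caching --cache_latents_to_disk" (and optionally "--skip_latents_validity_check"), alongside the usual model and dataset arguments.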