WIP: new latents caching
@@ -1,7 +1,7 @@
 import argparse
 import math
 import os
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from safetensors.torch import save_file
@@ -283,6 +283,98 @@ def sample_images(*args, **kwargs):
     return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
 
 
+class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
+    SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
+
+    def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
+        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
+        self.vae = vae
+
+    def get_latents_npz_path(self, absolute_path: str):
+        return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX
+
+    def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
+        if not self.cache_to_disk:
+            return False
+        if not os.path.exists(npz_path):
+            return False
+        if self.skip_disk_cache_validity_check:
+            return True
+
+        expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8)  # bucket_reso is (W, H)
+
+        try:
+            npz = np.load(npz_path)
+            if npz["latents"].shape[1:3] != expected_latents_size:
+                return False
+
+            if flip_aug:
+                if "latents_flipped" not in npz:
+                    return False
+                if npz["latents_flipped"].shape[1:3] != expected_latents_size:
+                    return False
+
+            if alpha_mask:
+                if "alpha_mask" not in npz:
+                    return False
+                if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
+                    return False
+            else:
+                if "alpha_mask" in npz:
+                    return False
+        except Exception as e:
+            logger.error(f"Error loading file: {npz_path}")
+            raise e
+
+        return True
+
+    def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
+        img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
+            image_infos, alpha_mask, random_crop
+        )
+        img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
+
+        with torch.no_grad():
+            latents = self.vae.encode(img_tensor).to("cpu")
+        if flip_aug:
+            img_tensor = torch.flip(img_tensor, dims=[3])
+            with torch.no_grad():
+                flipped_latents = self.vae.encode(img_tensor).to("cpu")
+        else:
+            flipped_latents = [None] * len(latents)
+
+        for info, latent, flipped_latent, alpha_mask, original_size, crop_ltrb in zip(
+            image_infos, latents, flipped_latents, alpha_masks, original_sizes, crop_ltrbs
+        ):
+            if self.cache_to_disk:
+                # save_latents_to_disk(
+                #     info.latents_npz,
+                #     latent,
+                #     info.latents_original_size,
+                #     info.latents_crop_ltrb,
+                #     flipped_latent,
+                #     alpha_mask,
+                # )
+                kwargs = {}
+                if flipped_latent is not None:
+                    kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
+                if alpha_mask is not None:
+                    kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
+                np.savez(
+                    info.latents_npz,
+                    latents=latent.float().cpu().numpy(),  # per-image latent, not the whole batch
+                    original_size=np.array(original_size),
+                    crop_ltrb=np.array(crop_ltrb),
+                    **kwargs,
+                )
+            else:
+                info.latents = latent
+                if flip_aug:
+                    info.latents_flipped = flipped_latent
+                info.alpha_mask = alpha_mask
+
+        if not train_util.HIGH_VRAM:
+            clean_memory_on_device(self.vae.device)
+
+
 # region Diffusers
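Since the commit is WIP, the diff stops at the strategy class itself: np, logger, and sd3_models are presumably imported elsewhere in the file, and nothing here shows how the strategy gets registered with the dataset code. A minimal sketch of how a caller might drive it, assuming a hypothetical load_sd3_vae helper and train_util.ImageInfo objects whose absolute_path and bucket_reso fields are already populated:

# Hypothetical usage sketch (not part of the commit); load_sd3_vae and the
# batching of ImageInfo objects are assumptions.
vae = load_sd3_vae("sd3_vae.safetensors")  # assumed helper returning sd3_models.SDVAE
strategy = Sd3LatentsCachingStrategy(vae, cache_to_disk=True, batch_size=4, skip_disk_cache_validity_check=False)

for batch in batches:  # lists of train_util.ImageInfo prepared by the caller
    to_encode = []
    for info in batch:
        info.latents_npz = strategy.get_latents_npz_path(info.absolute_path)
        # Skip images whose on-disk cache already matches the bucket
        # resolution and augmentation flags; re-encode everything else.
        if not strategy.is_disk_cached_latents_expected(info.bucket_reso, info.latents_npz, flip_aug=True, alpha_mask=False):
            to_encode.append(info)
    if to_encode:
        strategy.cache_batch_latents(to_encode, flip_aug=True, alpha_mask=False, random_crop=False)

Each image then gets a sibling <name>_sd3.npz next to it holding a latents array of shape (C, H/8, W/8), plus optional latents_flipped and alpha_mask entries and the original_size/crop_ltrb metadata that the validity check above inspects.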
||||