Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)

Refactor caching mechanism for latents and text encoder outputs, etc.
library/strategy_base.py (new file, 328 lines added)
@@ -0,0 +1,328 @@
# Base classes for platform strategies. This file defines the interfaces that concrete strategies implement.

import os
from typing import Any, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection

# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo

from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class TokenizeStrategy:
    _strategy = None  # strategy instance: actual strategy class

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        return cls._strategy

    def _load_tokenizer(
        self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
    ) -> Any:
        tokenizer = None
        if tokenizer_cache_dir:
            local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
            if os.path.exists(local_tokenizer_path):
                logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
                tokenizer = model_class.from_pretrained(local_tokenizer_path)  # same for v1 and v2

        if tokenizer is None:
            tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)

        if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
            logger.info(f"save tokenizer to cache: {local_tokenizer_path}")
            tokenizer.save_pretrained(local_tokenizer_path)

        return tokenizer

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        raise NotImplementedError

    def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
        """
        for SD1.5/2.0/SDXL
        TODO support batch input
        """
        if max_length is None:
            max_length = tokenizer.model_max_length - 2

        input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids

        if max_length > tokenizer.model_max_length:
            input_ids = input_ids.squeeze(0)
            iids_list = []
            if tokenizer.pad_token_id == tokenizer.eos_token_id:
                # v1
                # When the length is 77 tokens or more, the sequence is "<BOS> .... <EOS> <EOS> <EOS>" with a total
                # length such as 227, so convert it into a series of "<BOS> ... <EOS>" chunks.
                # AUTOMATIC1111's implementation also seems to split on ",", but keep it simple here for now.
                for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):  # (1, 152, 75)
                    ids_chunk = (
                        input_ids[0].unsqueeze(0),
                        input_ids[i : i + tokenizer.model_max_length - 2],
                        input_ids[-1].unsqueeze(0),
                    )
                    ids_chunk = torch.cat(ids_chunk)
                    iids_list.append(ids_chunk)
            else:
                # v2 or SDXL
                # When the length is 77 tokens or more, the sequence is "<BOS> .... <EOS> <PAD> <PAD>..." with a total
                # length such as 227, so convert it into a series of "<BOS> ... <EOS> <PAD> <PAD> ..." chunks.
                for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
                    ids_chunk = (
                        input_ids[0].unsqueeze(0),  # BOS
                        input_ids[i : i + tokenizer.model_max_length - 2],
                        input_ids[-1].unsqueeze(0),  # PAD or EOS
                    )
                    ids_chunk = torch.cat(ids_chunk)

                    # If the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing needs to be done.
                    # If it ends with x <PAD/EOS>, replace the last token with <EOS> (no effective change if it is already x <EOS>).
                    if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
                        ids_chunk[-1] = tokenizer.eos_token_id
                    # If the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
                    if ids_chunk[1] == tokenizer.pad_token_id:
                        ids_chunk[1] = tokenizer.eos_token_id

                    iids_list.append(ids_chunk)

            input_ids = torch.stack(iids_list)  # 3,77
        return input_ids
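

# Minimal usage sketch for the strategy registry above. The subclass, its name, and the
# pretrained tokenizer id are assumptions for illustration only; the real per-architecture
# strategies are defined in separate modules.
class _ExampleClipTokenizeStrategy(TokenizeStrategy):
    def __init__(self, tokenizer_cache_dir: Optional[str] = None) -> None:
        self.tokenizer = self._load_tokenizer(
            CLIPTokenizer, "openai/clip-vit-large-patch14", tokenizer_cache_dir=tokenizer_cache_dir
        )

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        texts = [text] if isinstance(text, str) else text
        # one stacked tensor per tokenizer; this sketch assumes a single CLIP tokenizer
        return [
            torch.stack(
                [self._get_input_ids(self.tokenizer, t, self.tokenizer.model_max_length) for t in texts], dim=0
            )
        ]


# Typical wiring at training-script startup (commented out; set_strategy may only be called once):
# TokenizeStrategy.set_strategy(_ExampleClipTokenizeStrategy())
# tokens = TokenizeStrategy.get_strategy().tokenize("a photo of a cat")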


class TextEncodingStrategy:
    _strategy = None  # strategy instance: actual strategy class

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
        return cls._strategy

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Encode tokens into embeddings and outputs.
        :param tokens: list of token tensors for each TextModel
        :return: list of output embeddings for each architecture
        """
        raise NotImplementedError
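

# Illustrative sketch of a concrete encoding strategy. The class name is hypothetical and it
# assumes a single CLIP text encoder; real implementations also handle multiple encoders,
# hidden-state selection, pooled outputs, and so on.
class _ExampleClipTextEncodingStrategy(TextEncodingStrategy):
    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        text_encoder: CLIPTextModel = models[0]
        input_ids = tokens[0].view(-1, tokens[0].shape[-1]).to(text_encoder.device)  # flatten chunks: (b * n, 77)
        with torch.no_grad():
            hidden_states = text_encoder(input_ids).last_hidden_state
        return [hidden_states]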


class TextEncoderOutputsCachingStrategy:
    _strategy = None  # strategy instance: actual strategy class

    def __init__(
        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
    ) -> None:
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
        self._is_partial = is_partial

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
        return cls._strategy

    @property
    def cache_to_disk(self):
        return self._cache_to_disk

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def is_partial(self):
        return self._is_partial

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        raise NotImplementedError

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        raise NotImplementedError

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        raise NotImplementedError

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
    ):
        raise NotImplementedError
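

# Hedged sketch of how caller code might drive the caching strategy: skip items whose cached
# npz already looks valid, then cache the rest in batches. It assumes the three strategies were
# registered via set_strategy and treats batch items as plain image paths for simplicity;
# the actual training code passes richer per-image records.
def _example_cache_text_encoder_outputs(models: List[Any], image_paths: List[str]) -> None:
    tokenize_strategy = TokenizeStrategy.get_strategy()
    encoding_strategy = TextEncodingStrategy.get_strategy()
    caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()

    to_cache = [
        p
        for p in image_paths
        if not (
            caching_strategy.cache_to_disk
            and caching_strategy.is_disk_cached_outputs_expected(caching_strategy.get_outputs_npz_path(p))
        )
    ]
    for i in range(0, len(to_cache), caching_strategy.batch_size):
        batch = to_cache[i : i + caching_strategy.batch_size]
        caching_strategy.cache_batch_outputs(tokenize_strategy, models, encoding_strategy, batch)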


class LatentsCachingStrategy:
    # TODO commonize utility functions to this class, such as npz handling etc.

    _strategy = None  # strategy instance: actual strategy class

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check

    @classmethod
    def set_strategy(cls, strategy):
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
        return cls._strategy

    @property
    def cache_to_disk(self):
        return self._cache_to_disk

    @property
    def batch_size(self):
        return self._batch_size

    def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
        raise NotImplementedError

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        raise NotImplementedError

    def is_disk_cached_latents_expected(
        self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
    ) -> bool:
        raise NotImplementedError

    def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        raise NotImplementedError

    def _defualt_is_disk_cached_latents_expected(
        self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
    ):
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)

        try:
            npz = np.load(npz_path)
            if npz["latents"].shape[1:3] != expected_latents_size:
                return False

            if flip_aug:
                if "latents_flipped" not in npz:
                    return False
                if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                    return False

            if alpha_mask:
                if "alpha_mask" not in npz:
                    return False
                if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
                    return False
            else:
                if "alpha_mask" in npz:
                    return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True
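
    # Worked example of the size check above: with latents_stride=8 and bucket_reso=(W, H)=(768, 512),
    # expected_latents_size is (512 // 8, 768 // 8) = (64, 96), which is compared against
    # npz["latents"].shape[1:3], i.e. the (H/8, W/8) part of a (C, H/8, W/8) latent array.
    #
    # A concrete subclass would typically delegate to the helper; the stride of 8 below (a VAE that
    # downsamples by 8x) is an assumption for illustration, not part of this base class:
    # def is_disk_cached_latents_expected(self, bucket_reso, npz_path, flip_aug, alpha_mask) -> bool:
    #     return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)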

    # TODO remove circular dependency for ImageInfo
    def _default_cache_batch_latents(
        self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
    ):
        """
        Default implementation for cache_batch_latents. Image loading, VAE encoding, flipping, and alpha mask handling are common.
        """
        from library import train_util  # import here to avoid circular import

        img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
            image_infos, alpha_mask, random_crop
        )
        img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)

        with torch.no_grad():
            latents_tensors = encode_by_vae(img_tensor).to("cpu")
        if flip_aug:
            img_tensor = torch.flip(img_tensor, dims=[3])
            with torch.no_grad():
                flipped_latents = encode_by_vae(img_tensor).to("cpu")
        else:
            flipped_latents = [None] * len(latents_tensors)

        # for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
        for i in range(len(image_infos)):
            info = image_infos[i]
            latents = latents_tensors[i]
            flipped_latent = flipped_latents[i]
            alpha_mask = alpha_masks[i]
            original_size = original_sizes[i]
            crop_ltrb = crop_ltrbs[i]

            if self.cache_to_disk:
                self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
            else:
                info.latents_original_size = original_size
                info.latents_crop_ltrb = crop_ltrb
                info.latents = latents
                if flip_aug:
                    info.latents_flipped = flipped_latent
                info.alpha_mask = alpha_mask

    def load_latents_from_disk(
        self, npz_path: str
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        npz = np.load(npz_path)
        if "latents" not in npz:
            raise ValueError(f"error: npz is old format. please re-generate {npz_path}")

        latents = npz["latents"]
        original_size = npz["original_size"].tolist()
        crop_ltrb = npz["crop_ltrb"].tolist()
        flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
        alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
        return latents, original_size, crop_ltrb, flipped_latents, alpha_mask

    def save_latents_to_disk(
        self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
    ):
        kwargs = {}
        if flipped_latents_tensor is not None:
            kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
        if alpha_mask is not None:
            kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
        np.savez(
            npz_path,
            latents=latents_tensor.float().cpu().numpy(),
            original_size=np.array(original_size),
            crop_ltrb=np.array(crop_ltrb),
            **kwargs,
        )
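

# Round-trip sketch for the npz latent cache. Everything here is illustrative: the shapes, the
# direct use of the base class (in practice a concrete subclass is registered via
# LatentsCachingStrategy.set_strategy), and the helper name are assumptions.
def _example_latents_npz_roundtrip(npz_path: str) -> None:
    # npz_path should end with ".npz"; np.savez appends the suffix otherwise and np.load would then miss it
    strategy = LatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)

    latents = torch.zeros(4, 64, 64)  # (C, H/8, W/8) for a 512x512 image with an 8x VAE stride
    strategy.save_latents_to_disk(npz_path, latents, original_size=[512, 512], crop_ltrb=[0, 0, 0, 0])

    loaded, original_size, crop_ltrb, flipped, alpha_mask = strategy.load_latents_from_disk(npz_path)
    assert loaded.shape == (4, 64, 64) and flipped is None and alpha_mask is None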