mirror of https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
329 lines · 13 KiB · Python
# base class for platform strategies. this file defines the interface for strategies
|
||
|
||
import os
|
||
from typing import Any, List, Optional, Tuple, Union
|
||
|
||
import numpy as np
|
||
import torch
|
||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||
|
||
|
||
# TODO remove circular import by moving ImageInfo to a separate file
|
||
# from library.train_util import ImageInfo
|
||
|
||
from library.utils import setup_logging
|
||
|
||
setup_logging()
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TokenizeStrategy:
    """Abstract base for tokenization strategies.

    Holds a process-wide singleton instance installed via set_strategy(),
    and provides shared helpers for loading tokenizers (with an optional
    local disk cache) and for building padded input-id tensors for
    SD1.5/2.0/SDXL-style CLIP tokenizers.
    """

    _strategy = None  # strategy instance: actual strategy class

    @classmethod
    def set_strategy(cls, strategy):
        # Singleton setter: installing a second strategy is a programming error.
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        # Returns the installed strategy, or None if set_strategy() was never called.
        return cls._strategy

    def _load_tokenizer(
        self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
    ) -> Any:
        """Load a tokenizer, optionally via a local disk cache.

        :param model_class: tokenizer class exposing from_pretrained()/save_pretrained()
        :param model_id: hub model id; "/" is replaced by "_" to form the cache dir name
        :param subfolder: optional subfolder passed to from_pretrained for the hub load
        :param tokenizer_cache_dir: if given, load from / save to this directory
        :return: the loaded tokenizer instance
        """
        tokenizer = None
        if tokenizer_cache_dir:
            local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
            if os.path.exists(local_tokenizer_path):
                logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
                tokenizer = model_class.from_pretrained(local_tokenizer_path)  # same for v1 and v2

        if tokenizer is None:
            tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)

        # populate the disk cache on first download
        if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
            logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
            tokenizer.save_pretrained(local_tokenizer_path)

        return tokenizer

    def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
        # Subclasses return one token tensor per tokenizer of the architecture.
        raise NotImplementedError

    def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
        """
        for SD1.5/2.0/SDXL
        TODO support batch input
        """
        if max_length is None:
            max_length = tokenizer.model_max_length - 2

        input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids

        if max_length > tokenizer.model_max_length:
            # Long prompt: re-chunk the single long sequence into consecutive
            # model_max_length (77-token) windows, each wrapped in BOS/EOS.
            input_ids = input_ids.squeeze(0)
            iids_list = []
            if tokenizer.pad_token_id == tokenizer.eos_token_id:
                # v1
                # With more than 77 tokens the sequence is "<BOS> ... <EOS> <EOS> <EOS>"
                # (total e.g. 227), so convert it into consecutive "<BOS>...<EOS>" chunks.
                # A1111's implementation splits on "," but we keep it simple here.
                for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):  # (1, 152, 75)
                    ids_chunk = (
                        input_ids[0].unsqueeze(0),
                        input_ids[i : i + tokenizer.model_max_length - 2],
                        input_ids[-1].unsqueeze(0),
                    )
                    ids_chunk = torch.cat(ids_chunk)
                    iids_list.append(ids_chunk)
            else:
                # v2 or SDXL
                # With more than 77 tokens the sequence is "<BOS> ... <EOS> <PAD> <PAD>..."
                # (total e.g. 227), so convert it into consecutive
                # "<BOS>...<EOS> <PAD> <PAD> ..." chunks.
                for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
                    ids_chunk = (
                        input_ids[0].unsqueeze(0),  # BOS
                        input_ids[i : i + tokenizer.model_max_length - 2],
                        input_ids[-1].unsqueeze(0),
                    )  # PAD or EOS
                    ids_chunk = torch.cat(ids_chunk)

                    # If the chunk ends with <EOS> <PAD> or <PAD> <PAD>, nothing to do.
                    # If it ends with x <PAD/EOS>, replace the last token with <EOS>
                    # (no effective change when it was already x <EOS>).
                    if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
                        ids_chunk[-1] = tokenizer.eos_token_id
                    # If the chunk starts with <BOS> <PAD> ..., change it to <BOS> <EOS> <PAD> ...
                    if ids_chunk[1] == tokenizer.pad_token_id:
                        ids_chunk[1] = tokenizer.eos_token_id

                    iids_list.append(ids_chunk)

            input_ids = torch.stack(iids_list)  # 3,77
        return input_ids
|
||
|
||
|
||
class TextEncodingStrategy:
    """Singleton base class for text-encoder encoding strategies.

    Exactly one strategy instance may be installed per process via
    set_strategy(); subclasses implement encode_tokens() for a specific
    model architecture.
    """

    _strategy = None  # the active strategy instance, or None if not yet installed

    @classmethod
    def set_strategy(cls, strategy):
        """Install *strategy* as the process-wide instance; calling twice is an error."""
        if cls._strategy is None:
            cls._strategy = strategy
            return
        raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")

    @classmethod
    def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
        """Return the installed strategy, or None if none was set."""
        return cls._strategy

    def encode_tokens(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Encode tokens into embeddings and outputs.
        :param tokens: list of token tensors for each TextModel
        :return: list of output embeddings for each architecture
        """
        raise NotImplementedError
|
||
|
||
|
||
class TextEncoderOutputsCachingStrategy:
    """Singleton base class for caching text-encoder outputs.

    Stores caching configuration and defines the interface subclasses
    implement for locating, validating, loading and writing .npz caches
    of text-encoder outputs.
    """

    _strategy = None  # the active strategy instance, or None if not yet installed

    def __init__(
        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
    ) -> None:
        # Stash configuration; read-only views are exposed via the properties below.
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
        self._is_partial = is_partial

    @classmethod
    def set_strategy(cls, strategy):
        """Install *strategy* as the process-wide instance; calling twice is an error."""
        if cls._strategy is None:
            cls._strategy = strategy
            return
        raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")

    @classmethod
    def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
        """Return the installed strategy, or None if none was set."""
        return cls._strategy

    @property
    def cache_to_disk(self):
        # True when outputs are written to .npz files instead of kept in memory.
        return self._cache_to_disk

    @property
    def batch_size(self):
        # Number of items encoded per caching batch.
        return self._batch_size

    @property
    def is_partial(self):
        # True when only part of the text encoders' outputs are cached.
        return self._is_partial

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the .npz cache path for the given image path."""
        raise NotImplementedError

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """Load the cached text-encoder outputs from an .npz file."""
        raise NotImplementedError

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """Return True when the on-disk cache at npz_path is valid and usable."""
        raise NotImplementedError

    def cache_batch_outputs(
        self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
    ):
        """Encode and cache the outputs for a batch of items."""
        raise NotImplementedError
|
||
|
||
|
||
class LatentsCachingStrategy:
    """Singleton base class for caching VAE latents, in memory or as .npz files on disk."""

    # TODO commonize utillity functions to this class, such as npz handling etc.

    _strategy = None  # strategy instance: actual strategy class

    def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
        """
        :param cache_to_disk: save latents as .npz files instead of keeping them in memory
        :param batch_size: number of images encoded per VAE batch
        :param skip_disk_cache_validity_check: trust existing .npz files without shape checks
        """
        self._cache_to_disk = cache_to_disk
        self._batch_size = batch_size
        self.skip_disk_cache_validity_check = skip_disk_cache_validity_check

    @classmethod
    def set_strategy(cls, strategy):
        # Singleton setter: installing a second strategy is a programming error.
        if cls._strategy is not None:
            raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
        # Returns the installed strategy, or None if set_strategy() was never called.
        return cls._strategy

    @property
    def cache_to_disk(self):
        return self._cache_to_disk

    @property
    def batch_size(self):
        return self._batch_size

    def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
        """Return (width, height) encoded in the cache path, or (None, None) if unknown."""
        raise NotImplementedError

    def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
        """Return the .npz cache path for the given image path and size."""
        raise NotImplementedError

    def is_disk_cached_latents_expected(
        self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
    ) -> bool:
        """Return True when the on-disk cache at npz_path is valid for this bucket."""
        raise NotImplementedError

    def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
        """Encode a batch of images with the VAE and cache the latents."""
        raise NotImplementedError

    def _default_is_disk_cached_latents_expected(
        self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
    ):
        """Default validity check for an on-disk .npz latents cache.

        :param latents_stride: VAE downscale factor (pixels per latent cell)
        :param bucket_reso: bucket resolution as (W, H)
        :param npz_path: path of the cache file to validate
        :param flip_aug: whether a "latents_flipped" array is required
        :param alpha_mask: whether an "alpha_mask" array is required (and must be absent otherwise)
        :return: True when the cache exists and matches expectations
        :raises Exception: re-raises any error from reading the file, after logging
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride)  # bucket_reso is (W, H)

        try:
            # NpzFile keeps the underlying file open; use a context manager so the
            # handle is released even on the early-return paths (was leaked before).
            with np.load(npz_path) as npz:
                if npz["latents"].shape[1:3] != expected_latents_size:
                    return False

                if flip_aug:
                    if "latents_flipped" not in npz:
                        return False
                    if npz["latents_flipped"].shape[1:3] != expected_latents_size:
                        return False

                if alpha_mask:
                    if "alpha_mask" not in npz:
                        return False
                    if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
                        return False
                else:
                    if "alpha_mask" in npz:
                        return False
        except Exception:
            logger.error(f"Error loading file: {npz_path}")
            raise  # bare raise keeps the original traceback intact

        return True

    # TODO remove circular dependency for ImageInfo
    def _default_cache_batch_latents(
        self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
    ):
        """
        Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
        """
        from library import train_util  # import here to avoid circular import

        img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
            image_infos, alpha_mask, random_crop
        )
        img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)

        with torch.no_grad():
            latents_tensors = encode_by_vae(img_tensor).to("cpu")
        if flip_aug:
            img_tensor = torch.flip(img_tensor, dims=[3])  # horizontal flip (width axis)
            with torch.no_grad():
                flipped_latents = encode_by_vae(img_tensor).to("cpu")
        else:
            flipped_latents = [None] * len(latents_tensors)

        # iterate in lockstep instead of indexing (also avoids shadowing the
        # alpha_mask parameter, which the old index loop did)
        for info, latents, flipped_latent, alpha_msk, original_size, crop_ltrb in zip(
            image_infos, latents_tensors, flipped_latents, alpha_masks, original_sizes, crop_ltrbs
        ):
            if self.cache_to_disk:
                self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_msk)
            else:
                info.latents_original_size = original_size
                info.latents_crop_ltrb = crop_ltrb
                info.latents = latents
                if flip_aug:
                    info.latents_flipped = flipped_latent
                info.alpha_mask = alpha_msk

    def load_latents_from_disk(
        self, npz_path: str
    ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
        """Load (latents, original_size, crop_ltrb, flipped_latents, alpha_mask) from an .npz cache.

        :raises ValueError: when the file predates the current format (no "latents" key)
        """
        # context manager closes the NpzFile's handle (was leaked before); arrays
        # are materialized on access, so everything is read before the file closes
        with np.load(npz_path) as npz:
            if "latents" not in npz:
                raise ValueError(f"error: npz is old format. please re-generate {npz_path}")

            latents = npz["latents"]
            original_size = npz["original_size"].tolist()
            crop_ltrb = npz["crop_ltrb"].tolist()
            flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
            alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
        return latents, original_size, crop_ltrb, flipped_latents, alpha_mask

    def save_latents_to_disk(
        self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
    ):
        """Write latents (and optional flipped latents / alpha mask) to an .npz file.

        Tensors are converted to float32 numpy arrays on CPU; optional arrays are
        only written when provided, so their presence encodes the cache's options.
        """
        kwargs = {}
        if flipped_latents_tensor is not None:
            kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
        if alpha_mask is not None:
            kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
        np.savez(
            npz_path,
            latents=latents_tensor.float().cpu().numpy(),
            original_size=np.array(original_size),
            crop_ltrb=np.array(crop_ltrb),
            **kwargs,
        )
|