mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Refactor caching mechanism for latents and text encoder outputs, etc.
This commit is contained in:
@@ -104,8 +104,6 @@ class ControlNetSubsetParams(BaseSubsetParams):
|
||||
|
||||
@dataclass
|
||||
class BaseDatasetParams:
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
||||
max_token_length: int = None
|
||||
resolution: Optional[Tuple[int, int]] = None
|
||||
network_multiplier: float = 1.0
|
||||
debug_dataset: bool = False
|
||||
|
||||
@@ -38,7 +38,7 @@ class SDTokenizer:
|
||||
サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。
|
||||
Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings.
|
||||
"""
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer: CLIPTokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
empty = self.tokenizer("")["input_ids"]
|
||||
@@ -56,6 +56,19 @@ class SDTokenizer:
|
||||
self.inv_vocab = {v: k for k, v in vocab.items()}
|
||||
self.max_word_length = 8
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
"""
|
||||
Tokenize the text without weights.
|
||||
"""
|
||||
if type(text) == str:
|
||||
text = [text]
|
||||
batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
|
||||
# return tokens["input_ids"]
|
||||
|
||||
pad_token = self.end_token if self.pad_with_end else 0
|
||||
for tokens in batch_tokens["input_ids"]:
|
||||
assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}"
|
||||
|
||||
def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
|
||||
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
|
||||
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
|
||||
@@ -75,13 +88,14 @@ class SDTokenizer:
|
||||
for word in to_tokenize:
|
||||
batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
|
||||
batch.append((self.end_token, 1.0))
|
||||
print(len(batch), self.max_length, self.min_length)
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
|
||||
|
||||
# truncate to max_length
|
||||
# print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
|
||||
print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
|
||||
if truncate_to_max_length and len(batch) > self.max_length:
|
||||
batch = batch[: self.max_length]
|
||||
if truncate_length is not None and len(batch) > truncate_length:
|
||||
@@ -110,27 +124,38 @@ class SDXLClipGTokenizer(SDTokenizer):
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
def __init__(self, t5xxl=True):
|
||||
def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256):
|
||||
if t5xxl_max_length is None:
|
||||
t5xxl_max_length = 256
|
||||
|
||||
# TODO cache tokenizer settings locally or hold them in the repo like ComfyUI
|
||||
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
|
||||
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
|
||||
# self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||
# self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
self.t5xxl = T5XXLTokenizer() if t5xxl else None
|
||||
# t5xxl has 99999999 max length, clip has 77
|
||||
self.model_max_length = self.clip_l.max_length # 77
|
||||
self.t5xxl_max_length = t5xxl_max_length
|
||||
|
||||
def tokenize_with_weights(self, text: str):
|
||||
# temporary truncate to max_length even for t5xxl
|
||||
return (
|
||||
self.clip_l.tokenize_with_weights(text),
|
||||
self.clip_g.tokenize_with_weights(text),
|
||||
(
|
||||
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
|
||||
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length)
|
||||
if self.t5xxl is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
def tokenize(self, text: str):
|
||||
return (
|
||||
self.clip_l.tokenize(text),
|
||||
self.clip_g.tokenize(text),
|
||||
(self.t5xxl.tokenize(text) if self.t5xxl is not None else None),
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -1474,7 +1499,10 @@ class ClipTokenWeightEncoder:
|
||||
tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0]
|
||||
list_of_tokens.append(tokens)
|
||||
else:
|
||||
list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
|
||||
if isinstance(list_of_token_weight_pairs[0], torch.Tensor):
|
||||
list_of_tokens = [list(list_of_token_weight_pairs[0])]
|
||||
else:
|
||||
list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
|
||||
|
||||
out, pooled = self(list_of_tokens)
|
||||
if has_batch:
|
||||
@@ -1614,9 +1642,9 @@ class T5XXLModel(SDClipModel):
|
||||
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
|
||||
#################################################################################################
|
||||
|
||||
|
||||
"""
|
||||
class T5XXLTokenizer(SDTokenizer):
|
||||
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
|
||||
""Wraps the T5 Tokenizer from HF into the SDTokenizer interface""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -1627,6 +1655,7 @@ class T5XXLTokenizer(SDTokenizer):
|
||||
max_length=99999999,
|
||||
min_length=77,
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
class T5LayerNorm(torch.nn.Module):
|
||||
|
||||
@@ -280,111 +280,6 @@ def sample_images(*args, **kwargs):
|
||||
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||
|
||||
|
||||
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
|
||||
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
self.vae = None
|
||||
|
||||
def set_vae(self, vae: sd3_models.SDVAE):
|
||||
self.vae = vae
|
||||
|
||||
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
|
||||
if len(npz_file) == 0:
|
||||
return None, None
|
||||
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
|
||||
return int(w), int(h)
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H)
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if npz["latents"].shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
|
||||
if flip_aug:
|
||||
if "latents_flipped" not in npz:
|
||||
return False
|
||||
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
|
||||
if alpha_mask:
|
||||
if "alpha_mask" not in npz:
|
||||
return False
|
||||
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
|
||||
return False
|
||||
else:
|
||||
if "alpha_mask" in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
|
||||
image_infos, alpha_mask, random_crop
|
||||
)
|
||||
img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
latents_tensors = self.vae.encode(img_tensor).to("cpu")
|
||||
if flip_aug:
|
||||
img_tensor = torch.flip(img_tensor, dims=[3])
|
||||
with torch.no_grad():
|
||||
flipped_latents = self.vae.encode(img_tensor).to("cpu")
|
||||
else:
|
||||
flipped_latents = [None] * len(latents_tensors)
|
||||
|
||||
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
|
||||
for i in range(len(image_infos)):
|
||||
info = image_infos[i]
|
||||
latents = latents_tensors[i]
|
||||
flipped_latent = flipped_latents[i]
|
||||
alpha_mask = alpha_masks[i]
|
||||
original_size = original_sizes[i]
|
||||
crop_ltrb = crop_ltrbs[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
kwargs = {}
|
||||
if flipped_latent is not None:
|
||||
kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
np.savez(
|
||||
info.latents_npz,
|
||||
latents=latents.float().cpu().numpy(),
|
||||
original_size=np.array(original_size),
|
||||
crop_ltrb=np.array(crop_ltrb),
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
info.latents = latents
|
||||
if flip_aug:
|
||||
info.latents_flipped = flipped_latent
|
||||
info.alpha_mask = alpha_mask
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
clean_memory_on_device(self.vae.device)
|
||||
|
||||
|
||||
# region Diffusers
|
||||
|
||||
|
||||
@@ -384,6 +384,7 @@ def get_cond(
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
|
||||
print(t5_tokens)
|
||||
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)
|
||||
|
||||
|
||||
|
||||
@@ -327,7 +327,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
)
|
||||
|
||||
|
||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
||||
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
|
||||
parser.add_argument(
|
||||
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
||||
)
|
||||
|
||||
328
library/strategy_base.py
Normal file
328
library/strategy_base.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# base class for platform strategies. this file defines the interface for strategies
|
||||
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
|
||||
# TODO remove circular import by moving ImageInfo to a separate file
|
||||
# from library.train_util import ImageInfo
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenizeStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||
cls._strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
def _load_tokenizer(
|
||||
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
|
||||
) -> Any:
|
||||
tokenizer = None
|
||||
if tokenizer_cache_dir:
|
||||
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
||||
|
||||
if tokenizer is None:
|
||||
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
|
||||
|
||||
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
tokenizer.save_pretrained(local_tokenizer_path)
|
||||
|
||||
return tokenizer
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
|
||||
"""
|
||||
for SD1.5/2.0/SDXL
|
||||
TODO support batch input
|
||||
"""
|
||||
if max_length is None:
|
||||
max_length = tokenizer.model_max_length - 2
|
||||
|
||||
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
|
||||
|
||||
if max_length > tokenizer.model_max_length:
|
||||
input_ids = input_ids.squeeze(0)
|
||||
iids_list = []
|
||||
if tokenizer.pad_token_id == tokenizer.eos_token_id:
|
||||
# v1
|
||||
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
|
||||
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
|
||||
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
|
||||
ids_chunk = (
|
||||
input_ids[0].unsqueeze(0),
|
||||
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0),
|
||||
)
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
iids_list.append(ids_chunk)
|
||||
else:
|
||||
# v2 or SDXL
|
||||
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
|
||||
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
|
||||
ids_chunk = (
|
||||
input_ids[0].unsqueeze(0), # BOS
|
||||
input_ids[i : i + tokenizer.model_max_length - 2],
|
||||
input_ids[-1].unsqueeze(0),
|
||||
) # PAD or EOS
|
||||
ids_chunk = torch.cat(ids_chunk)
|
||||
|
||||
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
|
||||
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変える(x <EOS> なら結果的に変化なし)
|
||||
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
|
||||
ids_chunk[-1] = tokenizer.eos_token_id
|
||||
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
|
||||
if ids_chunk[1] == tokenizer.pad_token_id:
|
||||
ids_chunk[1] = tokenizer.eos_token_id
|
||||
|
||||
iids_list.append(ids_chunk)
|
||||
|
||||
input_ids = torch.stack(iids_list) # 3,77
|
||||
return input_ids
|
||||
|
||||
|
||||
class TextEncodingStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||
cls._strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
def encode_tokens(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Encode tokens into embeddings and outputs.
|
||||
:param tokens: list of token tensors for each TextModel
|
||||
:return: list of output embeddings for each architecture
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TextEncoderOutputsCachingStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
def __init__(
|
||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
|
||||
) -> None:
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
self._is_partial = is_partial
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||
cls._strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
|
||||
def is_partial(self):
|
||||
return self._is_partial
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LatentsCachingStrategy:
|
||||
# TODO commonize utillity functions to this class, such as npz handling etc.
|
||||
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||
cls._strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def _defualt_is_disk_cached_latents_expected(
|
||||
self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||
):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if npz["latents"].shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
|
||||
if flip_aug:
|
||||
if "latents_flipped" not in npz:
|
||||
return False
|
||||
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
|
||||
if alpha_mask:
|
||||
if "alpha_mask" not in npz:
|
||||
return False
|
||||
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
|
||||
return False
|
||||
else:
|
||||
if "alpha_mask" in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def _default_cache_batch_latents(
|
||||
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
|
||||
):
|
||||
"""
|
||||
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
||||
"""
|
||||
from library import train_util # import here to avoid circular import
|
||||
|
||||
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
|
||||
image_infos, alpha_mask, random_crop
|
||||
)
|
||||
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
latents_tensors = encode_by_vae(img_tensor).to("cpu")
|
||||
if flip_aug:
|
||||
img_tensor = torch.flip(img_tensor, dims=[3])
|
||||
with torch.no_grad():
|
||||
flipped_latents = encode_by_vae(img_tensor).to("cpu")
|
||||
else:
|
||||
flipped_latents = [None] * len(latents_tensors)
|
||||
|
||||
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
|
||||
for i in range(len(image_infos)):
|
||||
info = image_infos[i]
|
||||
latents = latents_tensors[i]
|
||||
flipped_latent = flipped_latents[i]
|
||||
alpha_mask = alpha_masks[i]
|
||||
original_size = original_sizes[i]
|
||||
crop_ltrb = crop_ltrbs[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
|
||||
else:
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
info.latents = latents
|
||||
if flip_aug:
|
||||
info.latents_flipped = flipped_latent
|
||||
info.alpha_mask = alpha_mask
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz:
|
||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
def save_latents_to_disk(
|
||||
self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
|
||||
):
|
||||
kwargs = {}
|
||||
if flipped_latents_tensor is not None:
|
||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
np.savez(
|
||||
npz_path,
|
||||
latents=latents_tensor.float().cpu().numpy(),
|
||||
original_size=np.array(original_size),
|
||||
crop_ltrb=np.array(crop_ltrb),
|
||||
**kwargs,
|
||||
)
|
||||
139
library/strategy_sd.py
Normal file
139
library/strategy_sd.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
from library import train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||
|
||||
|
||||
class SdTokenizeStrategy(TokenizeStrategy):
|
||||
def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
|
||||
"""
|
||||
max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
|
||||
"""
|
||||
logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
|
||||
if v2:
|
||||
self.tokenizer = self._load_tokenizer(
|
||||
CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
|
||||
)
|
||||
else:
|
||||
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
|
||||
if max_length is None:
|
||||
self.max_length = self.tokenizer.model_max_length
|
||||
else:
|
||||
self.max_length = max_length + 2
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||
|
||||
|
||||
class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self, clip_skip: Optional[int] = None) -> None:
|
||||
self.clip_skip = clip_skip
|
||||
|
||||
def encode_tokens(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
text_encoder = models[0]
|
||||
tokens = tokens[0]
|
||||
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
|
||||
|
||||
# tokens: b,n,77
|
||||
b_size = tokens.size()[0]
|
||||
max_token_length = tokens.size()[1] * tokens.size()[2]
|
||||
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
|
||||
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
|
||||
|
||||
if self.clip_skip is None:
|
||||
encoder_hidden_states = text_encoder(tokens)[0]
|
||||
else:
|
||||
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
|
||||
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||
|
||||
# bs*3, 77, 768 or 1024
|
||||
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
||||
|
||||
if max_token_length != model_max_length:
|
||||
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
|
||||
if not v1:
|
||||
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, model_max_length):
|
||||
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||
if i > 0:
|
||||
for j in range(len(chunk)):
|
||||
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
|
||||
# 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
else:
|
||||
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, model_max_length):
|
||||
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
|
||||
return [encoder_hidden_states]
|
||||
|
||||
|
||||
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
||||
# and we keep the old npz for the backward compatibility.
|
||||
|
||||
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
|
||||
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
|
||||
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
|
||||
|
||||
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
self.sd = sd
|
||||
self.suffix = (
|
||||
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
# does not include old npz
|
||||
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix)
|
||||
if len(npz_file) == 0:
|
||||
return None, None
|
||||
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
|
||||
return int(w), int(h)
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
# support old .npz
|
||||
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
|
||||
if os.path.exists(old_npz_file):
|
||||
return old_npz_file
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
229
library/strategy_sd3.py
Normal file
229
library/strategy_sd3.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import os
|
||||
import glob
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from library import sd3_utils, train_util
|
||||
from library import sd3_models
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
|
||||
|
||||
|
||||
class Sd3TokenizeStrategy(TokenizeStrategy):
|
||||
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
|
||||
self.t5xxl_max_length = t5xxl_max_length
|
||||
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
|
||||
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
|
||||
|
||||
l_tokens = l_tokens["input_ids"]
|
||||
g_tokens = g_tokens["input_ids"]
|
||||
t5_tokens = t5_tokens["input_ids"]
|
||||
|
||||
return [l_tokens, g_tokens, t5_tokens]
|
||||
|
||||
|
||||
class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def encode_tokens(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
clip_l, clip_g, t5xxl = models
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = tokens
|
||||
if l_tokens is None:
|
||||
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
|
||||
lg_out = None
|
||||
else:
|
||||
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
|
||||
l_out, l_pooled = clip_l(l_tokens)
|
||||
g_out, g_pooled = clip_g(g_tokens)
|
||||
lg_out = torch.cat([l_out, g_out], dim=-1)
|
||||
|
||||
if t5xxl is not None and t5_tokens is not None:
|
||||
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
|
||||
else:
|
||||
t5_out = None
|
||||
|
||||
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
|
||||
return [lg_out, t5_out, lg_pooled]
|
||||
|
||||
def concat_encodings(
|
||||
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
|
||||
if t5_out is None:
|
||||
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
|
||||
return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
|
||||
|
||||
|
||||
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
|
||||
|
||||
def __init__(
|
||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
|
||||
def is_disk_cached_outputs_expected(self, abs_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(self.get_outputs_npz_path(abs_path))
|
||||
if "clip_l" not in npz or "clip_g" not in npz:
|
||||
return False
|
||||
if "clip_l_pool" not in npz or "clip_g_pool" not in npz:
|
||||
return False
|
||||
# t5xxl is optional
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
lg_out = data["lg_out"]
|
||||
lg_pooled = data["lg_pooled"]
|
||||
t5_out = data["t5_out"] if "t5_out" in data else None
|
||||
return [lg_out, t5_out, lg_pooled]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
):
|
||||
captions = [info.caption for info in infos]
|
||||
|
||||
clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens]
|
||||
)
|
||||
|
||||
if lg_out.dtype == torch.bfloat16:
|
||||
lg_out = lg_out.float()
|
||||
if lg_pooled.dtype == torch.bfloat16:
|
||||
lg_pooled = lg_pooled.float()
|
||||
if t5_out is not None and t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
|
||||
lg_out = lg_out.cpu().numpy()
|
||||
lg_pooled = lg_pooled.cpu().numpy()
|
||||
if t5_out is not None:
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
lg_out_i = lg_out[i]
|
||||
t5_out_i = t5_out[i] if t5_out is not None else None
|
||||
lg_pooled_i = lg_pooled[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
kwargs = {}
|
||||
if t5_out is not None:
|
||||
kwargs["t5_out"] = t5_out_i
|
||||
np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs)
|
||||
else:
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
|
||||
|
||||
|
||||
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
|
||||
if len(npz_file) == 0:
|
||||
return None, None
|
||||
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
|
||||
return int(w), int(h)
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test code for Sd3TokenizeStrategy
|
||||
# tokenizer = sd3_models.SD3Tokenizer()
|
||||
strategy = Sd3TokenizeStrategy(256)
|
||||
text = "hello world"
|
||||
|
||||
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
|
||||
# print(l_tokens.shape)
|
||||
print(l_tokens)
|
||||
print(g_tokens)
|
||||
print(t5_tokens)
|
||||
|
||||
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
|
||||
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
|
||||
t5_tokens_2 = strategy.t5xxl(
|
||||
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
print(l_tokens_2)
|
||||
print(g_tokens_2)
|
||||
print(t5_tokens_2)
|
||||
|
||||
# compare
|
||||
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
|
||||
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
|
||||
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
|
||||
|
||||
text = ",".join(["hello world! this is long text"] * 50)
|
||||
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
|
||||
print(l_tokens)
|
||||
print(g_tokens)
|
||||
print(t5_tokens)
|
||||
|
||||
print(f"model max length l: {strategy.clip_l.model_max_length}")
|
||||
print(f"model max length g: {strategy.clip_g.model_max_length}")
|
||||
print(f"model max length t5: {strategy.t5xxl.model_max_length}")
|
||||
247
library/strategy_sdxl.py
Normal file
247
library/strategy_sdxl.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import os
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
|
||||
class SdxlTokenizeStrategy(TokenizeStrategy):
|
||||
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
|
||||
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
|
||||
|
||||
if max_length is None:
|
||||
self.max_length = self.tokenizer1.model_max_length
|
||||
else:
|
||||
self.max_length = max_length + 2
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
return (
|
||||
torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
|
||||
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
|
||||
)
|
||||
|
||||
|
||||
class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _pool_workaround(
|
||||
self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
|
||||
):
|
||||
r"""
|
||||
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
|
||||
instead of the hidden states for the EOS token
|
||||
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
|
||||
|
||||
Original code from CLIP's pooling function:
|
||||
|
||||
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
|
||||
\# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
|
||||
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
|
||||
]
|
||||
"""
|
||||
|
||||
# input_ids: b*n,77
|
||||
# find index for EOS token
|
||||
|
||||
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
|
||||
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
|
||||
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
||||
|
||||
# Create a mask where the EOS tokens are
|
||||
eos_token_mask = (input_ids == eos_token_id).int()
|
||||
|
||||
# Use argmax to find the last index of the EOS token for each element in the batch
|
||||
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
|
||||
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
|
||||
|
||||
# get hidden states for EOS token
|
||||
pooled_output = last_hidden_state[
|
||||
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
|
||||
]
|
||||
|
||||
# apply projection: projection may be of different dtype than last_hidden_state
|
||||
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
|
||||
pooled_output = pooled_output.to(last_hidden_state.dtype)
|
||||
|
||||
return pooled_output
|
||||
|
||||
def _get_hidden_states_sdxl(
|
||||
self,
|
||||
input_ids1: torch.Tensor,
|
||||
input_ids2: torch.Tensor,
|
||||
tokenizer1: CLIPTokenizer,
|
||||
tokenizer2: CLIPTokenizer,
|
||||
text_encoder1: Union[CLIPTextModel, torch.nn.Module],
|
||||
text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
|
||||
unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
|
||||
):
|
||||
# input_ids: b,n,77 -> b*n, 77
|
||||
b_size = input_ids1.size()[0]
|
||||
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
|
||||
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
|
||||
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
||||
input_ids1 = input_ids1.to(text_encoder1.device)
|
||||
input_ids2 = input_ids2.to(text_encoder2.device)
|
||||
|
||||
# text_encoder1
|
||||
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
|
||||
hidden_states1 = enc_out["hidden_states"][11]
|
||||
|
||||
# text_encoder2
|
||||
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
|
||||
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
|
||||
|
||||
# pool2 = enc_out["text_embeds"]
|
||||
unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
|
||||
pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||
|
||||
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||
n_size = 1 if max_token_length is None else max_token_length // 75
|
||||
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
|
||||
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
|
||||
|
||||
if max_token_length is not None:
|
||||
# bs*3, 77, 768 or 1024
|
||||
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, tokenizer1.model_max_length):
|
||||
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
|
||||
hidden_states1 = torch.cat(states_list, dim=1)
|
||||
|
||||
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
||||
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, tokenizer2.model_max_length):
|
||||
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||
# this causes an error:
|
||||
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
|
||||
# if i > 1:
|
||||
# for j in range(len(chunk)): # batch_size
|
||||
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||
hidden_states2 = torch.cat(states_list, dim=1)
|
||||
|
||||
# pool はnの最初のものを使う
|
||||
pool2 = pool2[::n_size]
|
||||
|
||||
return hidden_states1, hidden_states2, pool2
|
||||
|
||||
def encode_tokens(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
tokenize_strategy: TokenizeStrategy
|
||||
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]
|
||||
tokens: List of tokens, for text_encoder1 and text_encoder2
|
||||
"""
|
||||
if len(models) == 2:
|
||||
text_encoder1, text_encoder2 = models
|
||||
unwrapped_text_encoder2 = None
|
||||
else:
|
||||
text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
|
||||
tokens1, tokens2 = tokens
|
||||
sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
|
||||
tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
|
||||
|
||||
hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
|
||||
tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
|
||||
)
|
||||
return [hidden_states1, hidden_states2, pool2]
|
||||
|
||||
|
||||
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
||||
|
||||
def __init__(
|
||||
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
|
||||
def is_disk_cached_outputs_expected(self, abs_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(self.get_outputs_npz_path(abs_path))
|
||||
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
hidden_state1 = data["hidden_state1"]
|
||||
hidden_state2 = data["hidden_state2"]
|
||||
pool2 = data["pool2"]
|
||||
return [hidden_state1, hidden_state2, pool2]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
):
|
||||
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
|
||||
captions = [info.caption for info in infos]
|
||||
|
||||
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy, models, [tokens1, tokens2]
|
||||
)
|
||||
if hidden_state1.dtype == torch.bfloat16:
|
||||
hidden_state1 = hidden_state1.float()
|
||||
if hidden_state2.dtype == torch.bfloat16:
|
||||
hidden_state2 = hidden_state2.float()
|
||||
if pool2.dtype == torch.bfloat16:
|
||||
pool2 = pool2.float()
|
||||
|
||||
hidden_state1 = hidden_state1.cpu().numpy()
|
||||
hidden_state2 = hidden_state2.cpu().numpy()
|
||||
pool2 = pool2.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
hidden_state1_i = hidden_state1[i]
|
||||
hidden_state2_i = hidden_state2[i]
|
||||
pool2_i = pool2[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state1=hidden_state1_i,
|
||||
hidden_state2=hidden_state2_i,
|
||||
pool2=pool2_i,
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
|
||||
@@ -12,6 +12,7 @@ import re
|
||||
import shutil
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
@@ -34,6 +35,7 @@ from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -81,10 +83,6 @@ logger = logging.getLogger(__name__)
|
||||
# from library.hypernetwork import replace_attentions_for_hypernetwork
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
|
||||
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||
|
||||
HIGH_VRAM = False
|
||||
|
||||
# checkpointファイル名
|
||||
@@ -148,18 +146,24 @@ class ImageInfo:
|
||||
self.image_size: Tuple[int, int] = None
|
||||
self.resized_size: Tuple[int, int] = None
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
self.latents: torch.Tensor = None
|
||||
self.latents_flipped: torch.Tensor = None
|
||||
self.latents_npz: str = None
|
||||
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
|
||||
self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size
|
||||
self.cond_img_path: str = None
|
||||
self.latents: Optional[torch.Tensor] = None
|
||||
self.latents_flipped: Optional[torch.Tensor] = None
|
||||
self.latents_npz: Optional[str] = None # set in cache_latents
|
||||
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
|
||||
None # crop left top right bottom in original pixel size, not latents size
|
||||
)
|
||||
self.cond_img_path: Optional[str] = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||
# SDXL, optional
|
||||
self.text_encoder_outputs_npz: Optional[str] = None
|
||||
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
|
||||
|
||||
# new
|
||||
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
|
||||
# old
|
||||
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
|
||||
@@ -359,47 +363,6 @@ class AugHelper:
|
||||
return self.color_aug if use_color_aug else None
|
||||
|
||||
|
||||
class LatentsCachingStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||
cls._strategy = strategy
|
||||
|
||||
@classmethod
|
||||
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseSubset:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -639,17 +602,12 @@ class ControlNetSubset(BaseSubset):
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
|
||||
max_token_length: int,
|
||||
resolution: Optional[Tuple[int, int]],
|
||||
network_multiplier: float,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
|
||||
|
||||
self.max_token_length = max_token_length
|
||||
# width/height is used when enable_bucket==False
|
||||
self.width, self.height = (None, None) if resolution is None else resolution
|
||||
self.network_multiplier = network_multiplier
|
||||
@@ -670,8 +628,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_no_upscale = None
|
||||
self.bucket_info = None # for metadata
|
||||
|
||||
self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
|
||||
self.current_step: int = 0
|
||||
@@ -690,6 +646,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# caching
|
||||
self.caching_mode = None # None, 'latents', 'text'
|
||||
|
||||
self.tokenize_strategy = None
|
||||
self.text_encoder_output_caching_strategy = None
|
||||
self.latents_caching_strategy = None
|
||||
|
||||
def set_current_strategies(self):
|
||||
self.tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
|
||||
def set_seed(self, seed):
|
||||
self.seed = seed
|
||||
@@ -979,22 +944,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
||||
|
||||
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
||||
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
||||
#
|
||||
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# # そのためバッチサイズを画像種類までに制限する
|
||||
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
||||
# num_of_image_types = len(set(bucket))
|
||||
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
# for batch_index in range(batch_count):
|
||||
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
# ↑ここまで
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
|
||||
@@ -1027,12 +976,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
]
|
||||
)
|
||||
|
||||
def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
r"""
|
||||
a brand new method to cache latents. This method caches latents with caching strategy.
|
||||
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
|
||||
"""
|
||||
logger.info("caching latents with caching strategy.")
|
||||
caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
# sort by resolution
|
||||
@@ -1088,7 +1038,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
logger.info("caching latents...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
@@ -1145,6 +1095,56 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
r"""
|
||||
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
|
||||
"""
|
||||
tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||
text_encoding_strategy = TextEncodingStrategy.get_strategy()
|
||||
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||
batch_size = caching_strategy.batch_size or self.batch_size
|
||||
|
||||
# if cache to disk, don't cache TE outputs in non-main process
|
||||
if caching_strategy.cache_to_disk and not is_main_process:
|
||||
return
|
||||
|
||||
logger.info("caching Text Encoder outputs with caching strategy.")
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
batch.append(info)
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
if len(batches) == 0:
|
||||
logger.info("no Text Encoder outputs to cache")
|
||||
return
|
||||
|
||||
# iterate batches
|
||||
logger.info("caching Text Encoder outputs...")
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch)
|
||||
|
||||
# if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
|
||||
# this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
|
||||
# to support SD1/2, it needs a flag for v2, but it is postponed
|
||||
@@ -1188,6 +1188,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching text encoder outputs.")
|
||||
|
||||
tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
@@ -1229,7 +1231,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
|
||||
batch.append((info, input_ids1, input_ids2))
|
||||
else:
|
||||
l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption)
|
||||
batch.append((info, l_tokens, g_tokens, t5_tokens))
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
@@ -1347,7 +1349,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
loss_weights = []
|
||||
captions = []
|
||||
input_ids_list = []
|
||||
input_ids2_list = []
|
||||
latents_list = []
|
||||
alpha_mask_list = []
|
||||
images = []
|
||||
@@ -1355,16 +1356,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
crop_top_lefts = []
|
||||
target_sizes_hw = []
|
||||
flippeds = [] # 変数名が微妙
|
||||
text_encoder_outputs1_list = []
|
||||
text_encoder_outputs2_list = []
|
||||
text_encoder_pool2_list = []
|
||||
text_encoder_outputs_list = []
|
||||
|
||||
for image_key in bucket[image_index : image_index + bucket_batch_size]:
|
||||
image_info = self.image_data[image_key]
|
||||
subset = self.image_to_subset[image_key]
|
||||
loss_weights.append(
|
||||
self.prior_loss_weight if image_info.is_reg else 1.0
|
||||
) # in case of fine tuning, is_reg is always False
|
||||
|
||||
# in case of fine tuning, is_reg is always False
|
||||
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
|
||||
|
||||
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
|
||||
|
||||
@@ -1381,7 +1380,9 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz)
|
||||
)
|
||||
if flipped:
|
||||
latents = flipped_latents
|
||||
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
|
||||
@@ -1470,75 +1471,67 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# captionとtext encoder outputを処理する
|
||||
caption = image_info.caption # default
|
||||
if image_info.text_encoder_outputs1 is not None:
|
||||
text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
|
||||
text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
|
||||
text_encoder_pool2_list.append(image_info.text_encoder_pool2)
|
||||
captions.append(caption)
|
||||
|
||||
tokenization_required = (
|
||||
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
|
||||
)
|
||||
text_encoder_outputs = None
|
||||
input_ids = None
|
||||
|
||||
if image_info.text_encoder_outputs is not None:
|
||||
# cached
|
||||
text_encoder_outputs = image_info.text_encoder_outputs
|
||||
elif image_info.text_encoder_outputs_npz is not None:
|
||||
text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
|
||||
# on disk
|
||||
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
|
||||
image_info.text_encoder_outputs_npz
|
||||
)
|
||||
text_encoder_outputs1_list.append(text_encoder_outputs1)
|
||||
text_encoder_outputs2_list.append(text_encoder_outputs2)
|
||||
text_encoder_pool2_list.append(text_encoder_pool2)
|
||||
captions.append(caption)
|
||||
else:
|
||||
tokenization_required = True
|
||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||
|
||||
if tokenization_required:
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
captions.append(caption)
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
|
||||
# if self.XTI_layers:
|
||||
# caption_layer = []
|
||||
# for layer in self.XTI_layers:
|
||||
# token_strings_from = " ".join(self.token_strings)
|
||||
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
# caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
# caption_layer.append(caption_)
|
||||
# captions.append(caption_layer)
|
||||
# else:
|
||||
# captions.append(caption)
|
||||
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
# TODO get_input_ids must support SD3
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
input_ids_list.append(token_caption)
|
||||
# if not self.token_padding_disabled: # this option might be omitted in future
|
||||
# # TODO get_input_ids must support SD3
|
||||
# if self.XTI_layers:
|
||||
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
# else:
|
||||
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
# input_ids_list.append(token_caption)
|
||||
|
||||
if len(self.tokenizers) > 1:
|
||||
if self.XTI_layers:
|
||||
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
else:
|
||||
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
input_ids2_list.append(token_caption2)
|
||||
# if len(self.tokenizers) > 1:
|
||||
# if self.XTI_layers:
|
||||
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
# else:
|
||||
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
# input_ids2_list.append(token_caption2)
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
captions.append(caption)
|
||||
|
||||
def none_or_stack_elements(tensors_list, converter):
|
||||
# [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
|
||||
if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None:
|
||||
return None
|
||||
return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
|
||||
if len(text_encoder_outputs1_list) == 0:
|
||||
if self.token_padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
if len(self.tokenizers) > 1:
|
||||
example["input_ids2"] = self.tokenizer[1](
|
||||
captions, padding=True, truncation=True, return_tensors="pt"
|
||||
).input_ids
|
||||
else:
|
||||
example["input_ids2"] = None
|
||||
else:
|
||||
example["input_ids"] = torch.stack(input_ids_list)
|
||||
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
|
||||
example["text_encoder_outputs1_list"] = None
|
||||
example["text_encoder_outputs2_list"] = None
|
||||
example["text_encoder_pool2_list"] = None
|
||||
else:
|
||||
example["input_ids"] = None
|
||||
example["input_ids2"] = None
|
||||
# # for assertion
|
||||
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
|
||||
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
|
||||
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
|
||||
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
|
||||
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
|
||||
example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
|
||||
example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
|
||||
|
||||
# if one of alpha_masks is not None, we need to replace None with ones
|
||||
none_or_not = [x is None for x in alpha_mask_list]
|
||||
@@ -1652,8 +1645,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
self,
|
||||
subsets: Sequence[DreamBoothSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
@@ -1664,7 +1655,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
prior_loss_weight: float,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"
|
||||
|
||||
@@ -1750,10 +1741,10 @@ class DreamBoothDataset(BaseDataset):
|
||||
# new caching: get image size from cache files
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
if strategy is not None:
|
||||
logger.info("get image size from cache files")
|
||||
logger.info("get image size from name of cache files")
|
||||
size_set_count = 0
|
||||
for i, img_path in enumerate(tqdm(img_paths)):
|
||||
w, h = strategy.get_image_size_from_image_absolute_path(img_path)
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(img_path)
|
||||
if w is not None and h is not None:
|
||||
sizes[i] = [w, h]
|
||||
size_set_count += 1
|
||||
@@ -1886,8 +1877,6 @@ class FineTuningDataset(BaseDataset):
|
||||
self,
|
||||
subsets: Sequence[FineTuningSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
@@ -1897,7 +1886,7 @@ class FineTuningDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: bool,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
@@ -2111,8 +2100,6 @@ class ControlNetDataset(BaseDataset):
|
||||
self,
|
||||
subsets: Sequence[ControlNetSubset],
|
||||
batch_size: int,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier: float,
|
||||
enable_bucket: bool,
|
||||
@@ -2122,7 +2109,7 @@ class ControlNetDataset(BaseDataset):
|
||||
bucket_no_upscale: bool,
|
||||
debug_dataset: float,
|
||||
) -> None:
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
db_subsets = []
|
||||
for subset in subsets:
|
||||
@@ -2160,8 +2147,6 @@ class ControlNetDataset(BaseDataset):
|
||||
self.dreambooth_dataset_delegate = DreamBoothDataset(
|
||||
db_subsets,
|
||||
batch_size,
|
||||
tokenizer,
|
||||
max_token_length,
|
||||
resolution,
|
||||
network_multiplier,
|
||||
enable_bucket,
|
||||
@@ -2221,6 +2206,9 @@ class ControlNetDataset(BaseDataset):
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
def set_current_strategies(self):
|
||||
return self.dreambooth_dataset_delegate.set_current_strategies()
|
||||
|
||||
def make_buckets(self):
|
||||
self.dreambooth_dataset_delegate.make_buckets()
|
||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||
@@ -2229,6 +2217,12 @@ class ControlNetDataset(BaseDataset):
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
|
||||
|
||||
def __len__(self):
|
||||
return self.dreambooth_dataset_delegate.__len__()
|
||||
|
||||
@@ -2314,6 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
# for dataset in self.datasets:
|
||||
# dataset.make_buckets()
|
||||
|
||||
def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy):
|
||||
"""
|
||||
DataLoader is run in multiple processes, so we need to set the strategy manually.
|
||||
"""
|
||||
for dataset in self.datasets:
|
||||
dataset.set_text_encoder_output_caching_strategy(strategy)
|
||||
|
||||
def enable_XTI(self, *args, **kwargs):
|
||||
for dataset in self.datasets:
|
||||
dataset.enable_XTI(*args, **kwargs)
|
||||
@@ -2323,10 +2324,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
|
||||
|
||||
def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
|
||||
def new_cache_latents(self, model: Any, is_main_process: bool):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.new_cache_latents(is_main_process, strategy)
|
||||
dataset.new_cache_latents(model, is_main_process)
|
||||
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||
@@ -2344,6 +2345,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
|
||||
)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.new_cache_text_encoder_outputs(models, is_main_process)
|
||||
|
||||
def set_caching_mode(self, caching_mode):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_caching_mode(caching_mode)
|
||||
@@ -2358,6 +2364,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def is_text_encoder_output_cacheable(self) -> bool:
|
||||
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
|
||||
|
||||
def set_current_strategies(self):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_strategies()
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_epoch(epoch)
|
||||
@@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
|
||||
|
||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||
# TODO update to use CachingStrategy
|
||||
def load_latents_from_disk(
|
||||
npz_path,
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz:
|
||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
# def load_latents_from_disk(
|
||||
# npz_path,
|
||||
# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
# npz = np.load(npz_path)
|
||||
# if "latents" not in npz:
|
||||
# raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
# latents = npz["latents"]
|
||||
# original_size = npz["original_size"].tolist()
|
||||
# crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||
# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
|
||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
|
||||
kwargs = {}
|
||||
if flipped_latents_tensor is not None:
|
||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
np.savez(
|
||||
npz_path,
|
||||
latents=latents_tensor.float().cpu().numpy(),
|
||||
original_size=np.array(original_size),
|
||||
crop_ltrb=np.array(crop_ltrb),
|
||||
**kwargs,
|
||||
)
|
||||
# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
|
||||
# kwargs = {}
|
||||
# if flipped_latents_tensor is not None:
|
||||
# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
# if alpha_mask is not None:
|
||||
# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
# np.savez(
|
||||
# npz_path,
|
||||
# latents=latents_tensor.float().cpu().numpy(),
|
||||
# original_size=np.array(original_size),
|
||||
# crop_ltrb=np.array(crop_ltrb),
|
||||
# **kwargs,
|
||||
# )
|
||||
|
||||
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
@@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
example = train_dataset[idx]
|
||||
if example["latents"] is not None:
|
||||
logger.info(f"sample has latents from npz file: {example['latents'].size()}")
|
||||
for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
|
||||
for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate(
|
||||
zip(
|
||||
example["image_keys"],
|
||||
example["captions"],
|
||||
example["loss_weights"],
|
||||
example["input_ids"],
|
||||
# example["input_ids"],
|
||||
example["original_sizes_hw"],
|
||||
example["crop_top_lefts"],
|
||||
example["target_sizes_hw"],
|
||||
@@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
||||
if "network_multipliers" in example:
|
||||
print(f"network multiplier: {example['network_multipliers'][j]}")
|
||||
|
||||
if show_input_ids:
|
||||
logger.info(f"input ids: {iid}")
|
||||
if "input_ids2" in example:
|
||||
logger.info(f"input ids2: {example['input_ids2'][j]}")
|
||||
# if show_input_ids:
|
||||
# logger.info(f"input ids: {iid}")
|
||||
# if "input_ids2" in example:
|
||||
# logger.info(f"input ids2: {example['input_ids2'][j]}")
|
||||
if example["images"] is not None:
|
||||
im = example["images"][j]
|
||||
logger.info(f"image size: {im.size()}")
|
||||
@@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive):
|
||||
|
||||
|
||||
class MinimalDataset(BaseDataset):
|
||||
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
|
||||
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
|
||||
def __init__(self, resolution, network_multiplier, debug_dataset=False):
|
||||
super().__init__(resolution, network_multiplier, debug_dataset)
|
||||
|
||||
self.num_train_images = 0 # update in subclass
|
||||
self.num_reg_images = 0 # update in subclass
|
||||
@@ -2773,14 +2783,15 @@ def cache_batch_latents(
|
||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||
|
||||
if cache_to_disk:
|
||||
save_latents_to_disk(
|
||||
info.latents_npz,
|
||||
latent,
|
||||
info.latents_original_size,
|
||||
info.latents_crop_ltrb,
|
||||
flipped_latent,
|
||||
alpha_mask,
|
||||
)
|
||||
# save_latents_to_disk(
|
||||
# info.latents_npz,
|
||||
# latent,
|
||||
# info.latents_original_size,
|
||||
# info.latents_crop_ltrb,
|
||||
# flipped_latent,
|
||||
# alpha_mask,
|
||||
# )
|
||||
pass
|
||||
else:
|
||||
info.latents = latent
|
||||
if flip_aug:
|
||||
@@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
|
||||
)
|
||||
|
||||
|
||||
def load_tokenizer(args: argparse.Namespace):
|
||||
logger.info("prepare tokenizer")
|
||||
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
|
||||
|
||||
tokenizer: CLIPTokenizer = None
|
||||
if args.tokenizer_cache_dir:
|
||||
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
|
||||
if os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
|
||||
|
||||
if tokenizer is None:
|
||||
if args.v2:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
|
||||
else:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(original_path)
|
||||
|
||||
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
||||
logger.info(f"update token length: {args.max_token_length}")
|
||||
|
||||
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
|
||||
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
|
||||
tokenizer.save_pretrained(local_tokenizer_path)
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def prepare_accelerator(args: argparse.Namespace):
|
||||
"""
|
||||
this function also prepares deepspeed plugin
|
||||
@@ -5550,6 +5534,7 @@ def sample_images_common(
|
||||
):
|
||||
"""
|
||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||
TODO Use strategies here
|
||||
"""
|
||||
|
||||
if steps == 0:
|
||||
|
||||
Reference in New Issue
Block a user