Refactor caching mechanism for latents and text encoder outputs, etc.

This commit is contained in:
Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions

View File

@@ -104,8 +104,6 @@ class ControlNetSubsetParams(BaseSubsetParams):
@dataclass
class BaseDatasetParams:
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
max_token_length: int = None
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False

View File

@@ -38,7 +38,7 @@ class SDTokenizer:
サブクラスで各種の設定を行ってる。このクラスはその設定に基づき重み付きのトークン化を行うようだ。
Some settings are done in subclasses. This class seems to perform tokenization with weights based on those settings.
"""
self.tokenizer = tokenizer
self.tokenizer: CLIPTokenizer = tokenizer
self.max_length = max_length
self.min_length = min_length
empty = self.tokenizer("")["input_ids"]
@@ -56,6 +56,19 @@ class SDTokenizer:
self.inv_vocab = {v: k for k, v in vocab.items()}
self.max_word_length = 8
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
"""
Tokenize the text without weights.
"""
if type(text) == str:
text = [text]
batch_tokens = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.max_length, return_tensors="pt")
# return tokens["input_ids"]
pad_token = self.end_token if self.pad_with_end else 0
for tokens in batch_tokens["input_ids"]:
assert tokens[0] == self.start_token, f"tokens[0]: {tokens[0]}, start_token: {self.start_token}"
def tokenize_with_weights(self, text: str, truncate_to_max_length=True, truncate_length=None):
"""Tokenize the text, with weight values - presume 1.0 for all and ignore other features here.
The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3."""
@@ -75,13 +88,14 @@ class SDTokenizer:
for word in to_tokenize:
batch.extend([(t, 1) for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1]])
batch.append((self.end_token, 1.0))
print(len(batch), self.max_length, self.min_length)
if self.pad_to_max_length:
batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch)))
if self.min_length is not None and len(batch) < self.min_length:
batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch)))
# truncate to max_length
# print(f"batch: {batch}, truncate: {truncate}, len(batch): {len(batch)}, max_length: {self.max_length}")
print(f"batch: {batch}, max_length: {self.max_length}, truncate: {truncate_to_max_length}, truncate_length: {truncate_length}")
if truncate_to_max_length and len(batch) > self.max_length:
batch = batch[: self.max_length]
if truncate_length is not None and len(batch) > truncate_length:
@@ -110,27 +124,38 @@ class SDXLClipGTokenizer(SDTokenizer):
class SD3Tokenizer:
def __init__(self, t5xxl=True):
def __init__(self, t5xxl=True, t5xxl_max_length: Optional[int] = 256):
if t5xxl_max_length is None:
t5xxl_max_length = 256
# TODO cache tokenizer settings locally or hold them in the repo like ComfyUI
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
self.clip_l = SDTokenizer(tokenizer=clip_tokenizer)
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
# self.clip_l = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
# self.clip_g = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
self.t5xxl = T5XXLTokenizer() if t5xxl else None
# t5xxl has 99999999 max length, clip has 77
self.model_max_length = self.clip_l.max_length # 77
self.t5xxl_max_length = t5xxl_max_length
def tokenize_with_weights(self, text: str):
# temporary truncate to max_length even for t5xxl
return (
self.clip_l.tokenize_with_weights(text),
self.clip_g.tokenize_with_weights(text),
(
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.model_max_length)
self.t5xxl.tokenize_with_weights(text, truncate_to_max_length=False, truncate_length=self.t5xxl_max_length)
if self.t5xxl is not None
else None
),
)
def tokenize(self, text: str):
return (
self.clip_l.tokenize(text),
self.clip_g.tokenize(text),
(self.t5xxl.tokenize(text) if self.t5xxl is not None else None),
)
# endregion
@@ -1474,7 +1499,10 @@ class ClipTokenWeightEncoder:
tokens = [a[0] for a in pairs[0]] # I'm not sure why this is [0]
list_of_tokens.append(tokens)
else:
list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
if isinstance(list_of_token_weight_pairs[0], torch.Tensor):
list_of_tokens = [list(list_of_token_weight_pairs[0])]
else:
list_of_tokens = [[a[0] for a in list_of_token_weight_pairs[0]]]
out, pooled = self(list_of_tokens)
if has_batch:
@@ -1614,9 +1642,9 @@ class T5XXLModel(SDClipModel):
### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl
#################################################################################################
"""
class T5XXLTokenizer(SDTokenizer):
"""Wraps the T5 Tokenizer from HF into the SDTokenizer interface"""
""Wraps the T5 Tokenizer from HF into the SDTokenizer interface""
def __init__(self):
super().__init__(
@@ -1627,6 +1655,7 @@ class T5XXLTokenizer(SDTokenizer):
max_length=99999999,
min_length=77,
)
"""
class T5LayerNorm(torch.nn.Module):

View File

@@ -280,111 +280,6 @@ def sample_images(*args, **kwargs):
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.vae = None
def set_vae(self, vae: sd3_models.SDVAE):
self.vae = vae
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H)
try:
npz = np.load(npz_path)
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
with torch.no_grad():
latents_tensors = self.vae.encode(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = self.vae.encode(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
if self.cache_to_disk:
kwargs = {}
if flipped_latent is not None:
kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
info.latents_npz,
latents=latents.float().cpu().numpy(),
original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
else:
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
if not train_util.HIGH_VRAM:
clean_memory_on_device(self.vae.device)
# region Diffusers

View File

@@ -384,6 +384,7 @@ def get_cond(
dtype: Optional[torch.dtype] = None,
):
l_tokens, g_tokens, t5_tokens = tokenizer.tokenize_with_weights(prompt)
print(t5_tokens)
return get_cond_from_tokens(l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, device=device, dtype=dtype)

View File

@@ -327,7 +327,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
)
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
def add_sdxl_training_arguments(parser: argparse.ArgumentParser, support_text_encoder_caching: bool = True):
parser.add_argument(
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
)

328
library/strategy_base.py Normal file
View File

@@ -0,0 +1,328 @@
# base class for platform strategies. this file defines the interface for strategies
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TokenizeStrategy"]:
return cls._strategy
def _load_tokenizer(
self, model_class: Any, model_id: str, subfolder: Optional[str] = None, tokenizer_cache_dir: Optional[str] = None
) -> Any:
tokenizer = None
if tokenizer_cache_dir:
local_tokenizer_path = os.path.join(tokenizer_cache_dir, model_id.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = model_class.from_pretrained(local_tokenizer_path) # same for v1 and v2
if tokenizer is None:
tokenizer = model_class.from_pretrained(model_id, subfolder=subfolder)
if tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
raise NotImplementedError
def _get_input_ids(self, tokenizer: CLIPTokenizer, text: str, max_length: Optional[int] = None) -> torch.Tensor:
"""
for SD1.5/2.0/SDXL
TODO support batch input
"""
if max_length is None:
max_length = tokenizer.model_max_length - 2
input_ids = tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt").input_ids
if max_length > tokenizer.model_max_length:
input_ids = input_ids.squeeze(0)
iids_list = []
if tokenizer.pad_token_id == tokenizer.eos_token_id:
# v1
# 77以上の時は "<BOS> .... <EOS> <EOS> <EOS>" でトータル227とかになっているので、"<BOS>...<EOS>"の三連に変換する
# 1111氏のやつは , で区切る、とかしているようだが とりあえず単純に
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2): # (1, 152, 75)
ids_chunk = (
input_ids[0].unsqueeze(0),
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
)
ids_chunk = torch.cat(ids_chunk)
iids_list.append(ids_chunk)
else:
# v2 or SDXL
# 77以上の時は "<BOS> .... <EOS> <PAD> <PAD>..." でトータル227とかになっているので、"<BOS>...<EOS> <PAD> <PAD> ..."の三連に変換する
for i in range(1, max_length - tokenizer.model_max_length + 2, tokenizer.model_max_length - 2):
ids_chunk = (
input_ids[0].unsqueeze(0), # BOS
input_ids[i : i + tokenizer.model_max_length - 2],
input_ids[-1].unsqueeze(0),
) # PAD or EOS
ids_chunk = torch.cat(ids_chunk)
# 末尾が <EOS> <PAD> または <PAD> <PAD> の場合は、何もしなくてよい
# 末尾が x <PAD/EOS> の場合は末尾を <EOS> に変えるx <EOS> なら結果的に変化なし)
if ids_chunk[-2] != tokenizer.eos_token_id and ids_chunk[-2] != tokenizer.pad_token_id:
ids_chunk[-1] = tokenizer.eos_token_id
# 先頭が <BOS> <PAD> ... の場合は <BOS> <EOS> <PAD> ... に変える
if ids_chunk[1] == tokenizer.pad_token_id:
ids_chunk[1] = tokenizer.eos_token_id
iids_list.append(ids_chunk)
input_ids = torch.stack(iids_list) # 3,77
return input_ids
class TextEncodingStrategy:
_strategy = None # strategy instance: actual strategy class
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncodingStrategy"]:
return cls._strategy
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Encode tokens into embeddings and outputs.
:param tokens: list of token tensors for each TextModel
:return: list of output embeddings for each architecture
"""
raise NotImplementedError
class TextEncoderOutputsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._is_partial = is_partial
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
@property
def is_partial(self):
return self._is_partial
def get_outputs_npz_path(self, image_abs_path: str) -> str:
raise NotImplementedError
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
raise NotImplementedError
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
raise NotImplementedError
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
):
raise NotImplementedError
class LatentsCachingStrategy:
# TODO commonize utillity functions to this class, such as npz handling etc.
_strategy = None # strategy instance: actual strategy class
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
raise NotImplementedError
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
raise NotImplementedError
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
def _defualt_is_disk_cached_latents_expected(
self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
try:
npz = np.load(npz_path)
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
"""
from library import train_util # import here to avoid circular import
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)
with torch.no_grad():
latents_tensors = encode_by_vae(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = encode_by_vae(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
if self.cache_to_disk:
self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_ltrb = npz["crop_ltrb"].tolist()
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
):
kwargs = {}
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)

139
library/strategy_sd.py Normal file
View File

@@ -0,0 +1,139 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import CLIPTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
class SdTokenizeStrategy(TokenizeStrategy):
def __init__(self, v2: bool, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length does not include <BOS> and <EOS> (None, 75, 150, 225)
"""
logger.info(f"Using {'v2' if v2 else 'v1'} tokenizer")
if v2:
self.tokenizer = self._load_tokenizer(
CLIPTokenizer, V2_STABLE_DIFFUSION_ID, subfolder="tokenizer", tokenizer_cache_dir=tokenizer_cache_dir
)
else:
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
if max_length is None:
self.max_length = self.tokenizer.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
class SdTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, clip_skip: Optional[int] = None) -> None:
self.clip_skip = clip_skip
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
text_encoder = models[0]
tokens = tokens[0]
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
# tokens: b,n,77
b_size = tokens.size()[0]
max_token_length = tokens.size()[1] * tokens.size()[2]
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
if self.clip_skip is None:
encoder_hidden_states = text_encoder(tokens)[0]
else:
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if max_token_length != model_max_length:
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
if not v1:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
# 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
return [encoder_hidden_states]
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.sd = sd
self.suffix = (
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
# does not include old npz
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + self.suffix)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
# support old .npz
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
if os.path.exists(old_npz_file):
return old_npz_file
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

229
library/strategy_sd3.py Normal file
View File

@@ -0,0 +1,229 @@
import os
import glob
from typing import Any, List, Optional, Tuple, Union
import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
class Sd3TokenizeStrategy(TokenizeStrategy):
def __init__(self, t5xxl_max_length: int = 256, tokenizer_cache_dir: Optional[str] = None) -> None:
self.t5xxl_max_length = t5xxl_max_length
self.clip_l = self._load_tokenizer(CLIPTokenizer, CLIP_L_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g = self._load_tokenizer(CLIPTokenizer, CLIP_G_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.t5xxl = self._load_tokenizer(T5TokenizerFast, T5_XXL_TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
self.clip_g.pad_token_id = 0 # use 0 as pad token for clip_g
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
l_tokens = self.clip_l(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens = self.clip_g(text, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens = self.t5xxl(text, max_length=self.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt")
l_tokens = l_tokens["input_ids"]
g_tokens = g_tokens["input_ids"]
t5_tokens = t5_tokens["input_ids"]
return [l_tokens, g_tokens, t5_tokens]
class Sd3TextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
clip_l, clip_g, t5xxl = models
l_tokens, g_tokens, t5_tokens = tokens
if l_tokens is None:
assert g_tokens is None, "g_tokens must be None if l_tokens is None"
lg_out = None
else:
assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
l_out, l_pooled = clip_l(l_tokens)
g_out, g_pooled = clip_g(g_tokens)
lg_out = torch.cat([l_out, g_out], dim=-1)
if t5xxl is not None and t5_tokens is not None:
t5_out, _ = t5xxl(t5_tokens) # t5_out is [1, max length, 4096]
else:
t5_out = None
lg_pooled = torch.cat((l_pooled, g_pooled), dim=-1) if l_tokens is not None else None
return [lg_out, t5_out, lg_pooled]
def concat_encodings(
self, lg_out: torch.Tensor, t5_out: Optional[torch.Tensor], lg_pooled: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1]))
if t5_out is None:
t5_out = torch.zeros((lg_out.shape[0], 77, 4096), device=lg_out.device, dtype=lg_out.dtype)
return torch.cat([lg_out, t5_out], dim=-2), lg_pooled
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, abs_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(self.get_outputs_npz_path(abs_path))
if "clip_l" not in npz or "clip_g" not in npz:
return False
if "clip_l_pool" not in npz or "clip_g_pool" not in npz:
return False
# t5xxl is optional
except Exception as e:
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"] if "t5_out" in data else None
return [lg_out, t5_out, lg_pooled]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
captions = [info.caption for info in infos]
clip_l_tokens, clip_g_tokens, t5xxl_tokens = tokenize_strategy.tokenize(captions)
with torch.no_grad():
lg_out, t5_out, lg_pooled = text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [clip_l_tokens, clip_g_tokens, t5xxl_tokens]
)
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out is not None and t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
if t5_out is not None:
t5_out = t5_out.cpu().numpy()
for i, info in enumerate(infos):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i] if t5_out is not None else None
lg_pooled_i = lg_pooled[i]
if self.cache_to_disk:
kwargs = {}
if t5_out is not None:
kwargs["t5_out"] = t5_out_i
np.savez(info.text_encoder_outputs_npz, lg_out=lg_out_i, lg_pooled=lg_pooled_i, **kwargs)
else:
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i)
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
def get_image_size_from_disk_cache_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._defualt_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)
if __name__ == "__main__":
# test code for Sd3TokenizeStrategy
# tokenizer = sd3_models.SD3Tokenizer()
strategy = Sd3TokenizeStrategy(256)
text = "hello world"
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
# print(l_tokens.shape)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
texts = ["hello world", "the quick brown fox jumps over the lazy dog"]
l_tokens_2 = strategy.clip_l(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
g_tokens_2 = strategy.clip_g(texts, max_length=77, padding="max_length", truncation=True, return_tensors="pt")
t5_tokens_2 = strategy.t5xxl(
texts, max_length=strategy.t5xxl_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
print(l_tokens_2)
print(g_tokens_2)
print(t5_tokens_2)
# compare
print(torch.allclose(l_tokens, l_tokens_2["input_ids"][0]))
print(torch.allclose(g_tokens, g_tokens_2["input_ids"][0]))
print(torch.allclose(t5_tokens, t5_tokens_2["input_ids"][0]))
text = ",".join(["hello world! this is long text"] * 50)
l_tokens, g_tokens, t5_tokens = strategy.tokenize(text)
print(l_tokens)
print(g_tokens)
print(t5_tokens)
print(f"model max length l: {strategy.clip_l.model_max_length}")
print(f"model max length g: {strategy.clip_g.model_max_length}")
print(f"model max length t5: {strategy.t5xxl.model_max_length}")

247
library/strategy_sdxl.py Normal file
View File

@@ -0,0 +1,247 @@
import os
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
class SdxlTokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
if max_length is None:
self.max_length = self.tokenizer1.model_max_length
else:
self.max_length = max_length + 2
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return (
torch.stack([self._get_input_ids(self.tokenizer1, t, self.max_length) for t in text], dim=0),
torch.stack([self._get_input_ids(self.tokenizer2, t, self.max_length) for t in text], dim=0),
)
class SdxlTextEncodingStrategy(TextEncodingStrategy):
def __init__(self) -> None:
pass
def _pool_workaround(
self, text_encoder: CLIPTextModelWithProjection, last_hidden_state: torch.Tensor, input_ids: torch.Tensor, eos_token_id: int
):
r"""
workaround for CLIP's pooling bug: it returns the hidden states for the max token id as the pooled output
instead of the hidden states for the EOS token
If we use Textual Inversion, we need to use the hidden states for the EOS token as the pooled output
Original code from CLIP's pooling function:
\# text_embeds.shape = [batch_size, sequence_length, transformer.width]
\# take features from the eot embedding (eot_token is the highest number in each sequence)
\# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
]
"""
# input_ids: b*n,77
# find index for EOS token
# Following code is not working if one of the input_ids has multiple EOS tokens (very odd case)
# eos_token_index = torch.where(input_ids == eos_token_id)[1]
# eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# Create a mask where the EOS tokens are
eos_token_mask = (input_ids == eos_token_id).int()
# Use argmax to find the last index of the EOS token for each element in the batch
eos_token_index = torch.argmax(eos_token_mask, dim=1) # this will be 0 if there is no EOS token, it's fine
eos_token_index = eos_token_index.to(device=last_hidden_state.device)
# get hidden states for EOS token
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), eos_token_index
]
# apply projection: projection may be of different dtype than last_hidden_state
pooled_output = text_encoder.text_projection(pooled_output.to(text_encoder.text_projection.weight.dtype))
pooled_output = pooled_output.to(last_hidden_state.dtype)
return pooled_output
def _get_hidden_states_sdxl(
self,
input_ids1: torch.Tensor,
input_ids2: torch.Tensor,
tokenizer1: CLIPTokenizer,
tokenizer2: CLIPTokenizer,
text_encoder1: Union[CLIPTextModel, torch.nn.Module],
text_encoder2: Union[CLIPTextModelWithProjection, torch.nn.Module],
unwrapped_text_encoder2: Optional[CLIPTextModelWithProjection] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
max_token_length = input_ids1.size()[1] * input_ids1.size()[2]
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
input_ids1 = input_ids1.to(text_encoder1.device)
input_ids2 = input_ids2.to(text_encoder2.device)
# text_encoder1
enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
hidden_states1 = enc_out["hidden_states"][11]
# text_encoder2
enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
# pool2 = enc_out["text_embeds"]
unwrapped_text_encoder2 = unwrapped_text_encoder2 or text_encoder2
pool2 = self._pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
if max_token_length is not None:
# bs*3, 77, 768 or 1024
# encoder1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [hidden_states1[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer1.model_max_length):
states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states1[:, -1].unsqueeze(1)) # <EOS>
hidden_states1 = torch.cat(states_list, dim=1)
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [hidden_states2[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, tokenizer2.model_max_length):
chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2] # <BOS> の後から 最後の前まで
# this causes an error:
# RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
# if i > 1:
# for j in range(len(chunk)): # batch_size
# if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id: # 空、つまり <BOS> <EOS> <PAD> ...のパターン
# chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(hidden_states2[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
hidden_states2 = torch.cat(states_list, dim=1)
# pool はnの最初のものを使う
pool2 = pool2[::n_size]
return hidden_states1, hidden_states2, pool2
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
"""
Args:
tokenize_strategy: TokenizeStrategy
models: List of models, [text_encoder1, text_encoder2, unwrapped text_encoder2 (optional)]
tokens: List of tokens, for text_encoder1 and text_encoder2
"""
if len(models) == 2:
text_encoder1, text_encoder2 = models
unwrapped_text_encoder2 = None
else:
text_encoder1, text_encoder2, unwrapped_text_encoder2 = models
tokens1, tokens2 = tokens
sdxl_tokenize_strategy = tokenize_strategy # type: SdxlTokenizeStrategy
tokenizer1, tokenizer2 = sdxl_tokenize_strategy.tokenizer1, sdxl_tokenize_strategy.tokenizer2
hidden_states1, hidden_states2, pool2 = self._get_hidden_states_sdxl(
tokens1, tokens2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, unwrapped_text_encoder2
)
return [hidden_states1, hidden_states2, pool2]
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
def __init__(
self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool, is_partial: bool = False
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(self, abs_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(self.get_outputs_npz_path(abs_path)):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(self.get_outputs_npz_path(abs_path))
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {self.get_outputs_npz_path(abs_path)}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
pool2 = data["pool2"]
return [hidden_state1, hidden_state2, pool2]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
):
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
captions = [info.caption for info in infos]
tokens1, tokens2 = tokenize_strategy.tokenize(captions)
with torch.no_grad():
hidden_state1, hidden_state2, pool2 = sdxl_text_encoding_strategy.encode_tokens(
tokenize_strategy, models, [tokens1, tokens2]
)
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]

View File

@@ -12,6 +12,7 @@ import re
import shutil
import time
from typing import (
Any,
Dict,
List,
NamedTuple,
@@ -34,6 +35,7 @@ from tqdm import tqdm
import torch
from library.device_utils import init_ipex, clean_memory_on_device
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy, TextEncodingStrategy
init_ipex()
@@ -81,10 +83,6 @@ logger = logging.getLogger(__name__)
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.original_unet import UNet2DConditionModel
# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
HIGH_VRAM = False
# checkpointファイル名
@@ -148,18 +146,24 @@ class ImageInfo:
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
self.latents: torch.Tensor = None
self.latents_flipped: torch.Tensor = None
self.latents_npz: str = None
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
self.latents_crop_ltrb: Tuple[int, int] = None # crop left top right bottom in original pixel size, not latents size
self.cond_img_path: str = None
self.latents: Optional[torch.Tensor] = None
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_npz: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
None # crop left top right bottom in original pixel size, not latents size
)
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
# SDXL, optional
self.text_encoder_outputs_npz: Optional[str] = None
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
# old
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
@@ -359,47 +363,6 @@ class AugHelper:
return self.color_aug if use_color_aug else None
class LatentsCachingStrategy:
_strategy = None # strategy instance: actual strategy class
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
@classmethod
def set_strategy(cls, strategy):
if cls._strategy is not None:
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
cls._strategy = strategy
@classmethod
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
return cls._strategy
@property
def cache_to_disk(self):
return self._cache_to_disk
@property
def batch_size(self):
return self._batch_size
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
raise NotImplementedError
def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str:
raise NotImplementedError
def is_disk_cached_latents_expected(
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
) -> bool:
raise NotImplementedError
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
raise NotImplementedError
class BaseSubset:
def __init__(
self,
@@ -639,17 +602,12 @@ class ControlNetSubset(BaseSubset):
class BaseDataset(torch.utils.data.Dataset):
def __init__(
self,
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]],
max_token_length: int,
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
) -> None:
super().__init__()
self.tokenizers = tokenizer if isinstance(tokenizer, list) else [tokenizer]
self.max_token_length = max_token_length
# width/height is used when enable_bucket==False
self.width, self.height = (None, None) if resolution is None else resolution
self.network_multiplier = network_multiplier
@@ -670,8 +628,6 @@ class BaseDataset(torch.utils.data.Dataset):
self.bucket_no_upscale = None
self.bucket_info = None # for metadata
self.tokenizer_max_length = self.tokenizers[0].model_max_length if max_token_length is None else max_token_length + 2
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.current_step: int = 0
@@ -690,6 +646,15 @@ class BaseDataset(torch.utils.data.Dataset):
# caching
self.caching_mode = None # None, 'latents', 'text'
self.tokenize_strategy = None
self.text_encoder_output_caching_strategy = None
self.latents_caching_strategy = None
def set_current_strategies(self):
self.tokenize_strategy = TokenizeStrategy.get_strategy()
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
def set_seed(self, seed):
self.seed = seed
@@ -979,22 +944,6 @@ class BaseDataset(torch.utils.data.Dataset):
for batch_index in range(batch_count):
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
#  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
#
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
# # そのためバッチサイズを画像種類までに制限する
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない
# # TO DO 正則化画像をepochまたがりで利用する仕組み
# num_of_image_types = len(set(bucket))
# bucket_batch_size = min(self.batch_size, num_of_image_types)
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
# # logger.info(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
# for batch_index in range(batch_count):
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
# ↑ここまで
self.shuffle_buckets()
self._length = len(self.buckets_indices)
@@ -1027,12 +976,13 @@ class BaseDataset(torch.utils.data.Dataset):
]
)
def new_cache_latents(self, is_main_process: bool, caching_strategy: LatentsCachingStrategy):
def new_cache_latents(self, model: Any, is_main_process: bool):
r"""
a brand new method to cache latents. This method caches latents with caching strategy.
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
"""
logger.info("caching latents with caching strategy.")
caching_strategy = LatentsCachingStrategy.get_strategy()
image_infos = list(self.image_data.values())
# sort by resolution
@@ -1088,7 +1038,7 @@ class BaseDataset(torch.utils.data.Dataset):
logger.info("caching latents...")
for batch in tqdm(batches, smoothing=1, total=len(batches)):
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
caching_strategy.cache_batch_latents(batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
caching_strategy.cache_batch_latents(model, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
@@ -1145,6 +1095,56 @@ class BaseDataset(torch.utils.data.Dataset):
for batch in tqdm(batches, smoothing=1, total=len(batches)):
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
r"""
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
"""
tokenize_strategy = TokenizeStrategy.get_strategy()
text_encoding_strategy = TextEncodingStrategy.get_strategy()
caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
batch_size = caching_strategy.batch_size or self.batch_size
# if cache to disk, don't cache TE outputs in non-main process
if caching_strategy.cache_to_disk and not is_main_process:
return
logger.info("caching Text Encoder outputs with caching strategy.")
image_infos = list(self.image_data.values())
# split by resolution
batches = []
batch = []
logger.info("checking cache validity...")
for info in tqdm(image_infos):
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
info.text_encoder_outputs_npz = te_out_npz
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
if cache_available: # do not add to batch
continue
batch.append(info)
# if number of data in batch is enough, flush the batch
if len(batch) >= batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
if len(batches) == 0:
logger.info("no Text Encoder outputs to cache")
return
# iterate batches
logger.info("caching Text Encoder outputs...")
for batch in tqdm(batches, smoothing=1, total=len(batches)):
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch)
# if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
# this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
# to support SD1/2, it needs a flag for v2, but it is postponed
@@ -1188,6 +1188,8 @@ class BaseDataset(torch.utils.data.Dataset):
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
logger.info("caching text encoder outputs.")
tokenize_strategy = TokenizeStrategy.get_strategy()
if batch_size is None:
batch_size = self.batch_size
@@ -1229,7 +1231,7 @@ class BaseDataset(torch.utils.data.Dataset):
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
batch.append((info, input_ids1, input_ids2))
else:
l_tokens, g_tokens, t5_tokens = tokenizers[0].tokenize_with_weights(info.caption)
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption)
batch.append((info, l_tokens, g_tokens, t5_tokens))
if len(batch) >= batch_size:
@@ -1347,7 +1349,6 @@ class BaseDataset(torch.utils.data.Dataset):
loss_weights = []
captions = []
input_ids_list = []
input_ids2_list = []
latents_list = []
alpha_mask_list = []
images = []
@@ -1355,16 +1356,14 @@ class BaseDataset(torch.utils.data.Dataset):
crop_top_lefts = []
target_sizes_hw = []
flippeds = [] # 変数名が微妙
text_encoder_outputs1_list = []
text_encoder_outputs2_list = []
text_encoder_pool2_list = []
text_encoder_outputs_list = []
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
loss_weights.append(
self.prior_loss_weight if image_info.is_reg else 1.0
) # in case of fine tuning, is_reg is always False
# in case of fine tuning, is_reg is always False
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
flipped = subset.flip_aug and random.random() < 0.5 # not flipped or flipped with 50% chance
@@ -1381,7 +1380,9 @@ class BaseDataset(torch.utils.data.Dataset):
image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz)
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz)
)
if flipped:
latents = flipped_latents
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
@@ -1470,75 +1471,67 @@ class BaseDataset(torch.utils.data.Dataset):
# captionとtext encoder outputを処理する
caption = image_info.caption # default
if image_info.text_encoder_outputs1 is not None:
text_encoder_outputs1_list.append(image_info.text_encoder_outputs1)
text_encoder_outputs2_list.append(image_info.text_encoder_outputs2)
text_encoder_pool2_list.append(image_info.text_encoder_pool2)
captions.append(caption)
tokenization_required = (
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
)
text_encoder_outputs = None
input_ids = None
if image_info.text_encoder_outputs is not None:
# cached
text_encoder_outputs = image_info.text_encoder_outputs
elif image_info.text_encoder_outputs_npz is not None:
text_encoder_outputs1, text_encoder_outputs2, text_encoder_pool2 = load_text_encoder_outputs_from_disk(
# on disk
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
image_info.text_encoder_outputs_npz
)
text_encoder_outputs1_list.append(text_encoder_outputs1)
text_encoder_outputs2_list.append(text_encoder_outputs2)
text_encoder_pool2_list.append(text_encoder_pool2)
captions.append(caption)
else:
tokenization_required = True
text_encoder_outputs_list.append(text_encoder_outputs)
if tokenization_required:
caption = self.process_caption(subset, image_info.caption)
if self.XTI_layers:
caption_layer = []
for layer in self.XTI_layers:
token_strings_from = " ".join(self.token_strings)
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
caption_ = caption.replace(token_strings_from, token_strings_to)
caption_layer.append(caption_)
captions.append(caption_layer)
else:
captions.append(caption)
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
# if self.XTI_layers:
# caption_layer = []
# for layer in self.XTI_layers:
# token_strings_from = " ".join(self.token_strings)
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
# caption_ = caption.replace(token_strings_from, token_strings_to)
# caption_layer.append(caption_)
# captions.append(caption_layer)
# else:
# captions.append(caption)
if not self.token_padding_disabled: # this option might be omitted in future
# TODO get_input_ids must support SD3
if self.XTI_layers:
token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
else:
token_caption = self.get_input_ids(caption, self.tokenizers[0])
input_ids_list.append(token_caption)
# if not self.token_padding_disabled: # this option might be omitted in future
# # TODO get_input_ids must support SD3
# if self.XTI_layers:
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
# else:
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
# input_ids_list.append(token_caption)
if len(self.tokenizers) > 1:
if self.XTI_layers:
token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
else:
token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
input_ids2_list.append(token_caption2)
# if len(self.tokenizers) > 1:
# if self.XTI_layers:
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
# else:
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
# input_ids2_list.append(token_caption2)
input_ids_list.append(input_ids)
captions.append(caption)
def none_or_stack_elements(tensors_list, converter):
# [[clip_l, clip_g, t5xxl], [clip_l, clip_g, t5xxl], ...] -> [torch.stack(clip_l), torch.stack(clip_g), torch.stack(t5xxl)]
if len(tensors_list) == 0 or tensors_list[0] == None or len(tensors_list[0]) == 0 or tensors_list[0][0] is None:
return None
return [torch.stack([converter(x[i]) for x in tensors_list]) for i in range(len(tensors_list[0]))]
example = {}
example["loss_weights"] = torch.FloatTensor(loss_weights)
if len(text_encoder_outputs1_list) == 0:
if self.token_padding_disabled:
# padding=True means pad in the batch
example["input_ids"] = self.tokenizer[0](captions, padding=True, truncation=True, return_tensors="pt").input_ids
if len(self.tokenizers) > 1:
example["input_ids2"] = self.tokenizer[1](
captions, padding=True, truncation=True, return_tensors="pt"
).input_ids
else:
example["input_ids2"] = None
else:
example["input_ids"] = torch.stack(input_ids_list)
example["input_ids2"] = torch.stack(input_ids2_list) if len(self.tokenizers) > 1 else None
example["text_encoder_outputs1_list"] = None
example["text_encoder_outputs2_list"] = None
example["text_encoder_pool2_list"] = None
else:
example["input_ids"] = None
example["input_ids2"] = None
# # for assertion
# example["input_ids"] = torch.stack([self.get_input_ids(cap, self.tokenizers[0]) for cap in captions])
# example["input_ids2"] = torch.stack([self.get_input_ids(cap, self.tokenizers[1]) for cap in captions])
example["text_encoder_outputs1_list"] = torch.stack(text_encoder_outputs1_list)
example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list)
example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list)
example["text_encoder_outputs_list"] = none_or_stack_elements(text_encoder_outputs_list, torch.FloatTensor)
example["input_ids_list"] = none_or_stack_elements(input_ids_list, lambda x: x)
# if one of alpha_masks is not None, we need to replace None with ones
none_or_not = [x is None for x in alpha_mask_list]
@@ -1652,8 +1645,6 @@ class DreamBoothDataset(BaseDataset):
self,
subsets: Sequence[DreamBoothSubset],
batch_size: int,
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
@@ -1664,7 +1655,7 @@ class DreamBoothDataset(BaseDataset):
prior_loss_weight: float,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -1750,10 +1741,10 @@ class DreamBoothDataset(BaseDataset):
# new caching: get image size from cache files
strategy = LatentsCachingStrategy.get_strategy()
if strategy is not None:
logger.info("get image size from cache files")
logger.info("get image size from name of cache files")
size_set_count = 0
for i, img_path in enumerate(tqdm(img_paths)):
w, h = strategy.get_image_size_from_image_absolute_path(img_path)
w, h = strategy.get_image_size_from_disk_cache_path(img_path)
if w is not None and h is not None:
sizes[i] = [w, h]
size_set_count += 1
@@ -1886,8 +1877,6 @@ class FineTuningDataset(BaseDataset):
self,
subsets: Sequence[FineTuningSubset],
batch_size: int,
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
@@ -1897,7 +1886,7 @@ class FineTuningDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset)
self.batch_size = batch_size
@@ -2111,8 +2100,6 @@ class ControlNetDataset(BaseDataset):
self,
subsets: Sequence[ControlNetSubset],
batch_size: int,
tokenizer,
max_token_length,
resolution,
network_multiplier: float,
enable_bucket: bool,
@@ -2122,7 +2109,7 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: float,
) -> None:
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset)
db_subsets = []
for subset in subsets:
@@ -2160,8 +2147,6 @@ class ControlNetDataset(BaseDataset):
self.dreambooth_dataset_delegate = DreamBoothDataset(
db_subsets,
batch_size,
tokenizer,
max_token_length,
resolution,
network_multiplier,
enable_bucket,
@@ -2221,6 +2206,9 @@ class ControlNetDataset(BaseDataset):
self.conditioning_image_transforms = IMAGE_TRANSFORMS
def set_current_strategies(self):
return self.dreambooth_dataset_delegate.set_current_strategies()
def make_buckets(self):
self.dreambooth_dataset_delegate.make_buckets()
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
@@ -2229,6 +2217,12 @@ class ControlNetDataset(BaseDataset):
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def new_cache_latents(self, model: Any, is_main_process: bool):
return self.dreambooth_dataset_delegate.new_cache_latents(model, is_main_process)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
def __len__(self):
return self.dreambooth_dataset_delegate.__len__()
@@ -2314,6 +2308,13 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# for dataset in self.datasets:
# dataset.make_buckets()
def set_text_encoder_output_caching_strategy(self, strategy: TextEncoderOutputsCachingStrategy):
"""
DataLoader is run in multiple processes, so we need to set the strategy manually.
"""
for dataset in self.datasets:
dataset.set_text_encoder_output_caching_strategy(strategy)
def enable_XTI(self, *args, **kwargs):
for dataset in self.datasets:
dataset.enable_XTI(*args, **kwargs)
@@ -2323,10 +2324,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
logger.info(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
def new_cache_latents(self, is_main_process: bool, strategy: LatentsCachingStrategy):
def new_cache_latents(self, model: Any, is_main_process: bool):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_latents(is_main_process, strategy)
dataset.new_cache_latents(model, is_main_process)
def cache_text_encoder_outputs(
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
@@ -2344,6 +2345,11 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
)
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
for i, dataset in enumerate(self.datasets):
logger.info(f"[Dataset {i}]")
dataset.new_cache_text_encoder_outputs(models, is_main_process)
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
dataset.set_caching_mode(caching_mode)
@@ -2358,6 +2364,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def is_text_encoder_output_cacheable(self) -> bool:
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
def set_current_strategies(self):
for dataset in self.datasets:
dataset.set_current_strategies()
def set_current_epoch(self, epoch):
for dataset in self.datasets:
dataset.set_current_epoch(epoch)
@@ -2411,34 +2421,34 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
# TODO update to use CachingStrategy
def load_latents_from_disk(
npz_path,
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
# def load_latents_from_disk(
# npz_path,
# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
# npz = np.load(npz_path)
# if "latents" not in npz:
# raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_ltrb = npz["crop_ltrb"].tolist()
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
# latents = npz["latents"]
# original_size = npz["original_size"].tolist()
# crop_ltrb = npz["crop_ltrb"].tolist()
# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
kwargs = {}
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
# kwargs = {}
# if flipped_latents_tensor is not None:
# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
# if alpha_mask is not None:
# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
# np.savez(
# npz_path,
# latents=latents_tensor.float().cpu().numpy(),
# original_size=np.array(original_size),
# crop_ltrb=np.array(crop_ltrb),
# **kwargs,
# )
def debug_dataset(train_dataset, show_input_ids=False):
@@ -2465,12 +2475,12 @@ def debug_dataset(train_dataset, show_input_ids=False):
example = train_dataset[idx]
if example["latents"] is not None:
logger.info(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid, orgsz, crptl, trgsz, flpdz) in enumerate(
for j, (ik, cap, lw, orgsz, crptl, trgsz, flpdz) in enumerate(
zip(
example["image_keys"],
example["captions"],
example["loss_weights"],
example["input_ids"],
# example["input_ids"],
example["original_sizes_hw"],
example["crop_top_lefts"],
example["target_sizes_hw"],
@@ -2483,10 +2493,10 @@ def debug_dataset(train_dataset, show_input_ids=False):
if "network_multipliers" in example:
print(f"network multiplier: {example['network_multipliers'][j]}")
if show_input_ids:
logger.info(f"input ids: {iid}")
if "input_ids2" in example:
logger.info(f"input ids2: {example['input_ids2'][j]}")
# if show_input_ids:
# logger.info(f"input ids: {iid}")
# if "input_ids2" in example:
# logger.info(f"input ids2: {example['input_ids2'][j]}")
if example["images"] is not None:
im = example["images"][j]
logger.info(f"image size: {im.size()}")
@@ -2555,8 +2565,8 @@ def glob_images_pathlib(dir_path, recursive):
class MinimalDataset(BaseDataset):
def __init__(self, tokenizer, max_token_length, resolution, network_multiplier, debug_dataset=False):
super().__init__(tokenizer, max_token_length, resolution, network_multiplier, debug_dataset)
def __init__(self, resolution, network_multiplier, debug_dataset=False):
super().__init__(resolution, network_multiplier, debug_dataset)
self.num_train_images = 0 # update in subclass
self.num_reg_images = 0 # update in subclass
@@ -2773,14 +2783,15 @@ def cache_batch_latents(
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
save_latents_to_disk(
info.latents_npz,
latent,
info.latents_original_size,
info.latents_crop_ltrb,
flipped_latent,
alpha_mask,
)
# save_latents_to_disk(
# info.latents_npz,
# latent,
# info.latents_original_size,
# info.latents_crop_ltrb,
# flipped_latent,
# alpha_mask,
# )
pass
else:
info.latents = latent
if flip_aug:
@@ -4662,33 +4673,6 @@ def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool):
)
def load_tokenizer(args: argparse.Namespace):
logger.info("prepare tokenizer")
original_path = V2_STABLE_DIFFUSION_PATH if args.v2 else TOKENIZER_PATH
tokenizer: CLIPTokenizer = None
if args.tokenizer_cache_dir:
local_tokenizer_path = os.path.join(args.tokenizer_cache_dir, original_path.replace("/", "_"))
if os.path.exists(local_tokenizer_path):
logger.info(f"load tokenizer from cache: {local_tokenizer_path}")
tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path) # same for v1 and v2
if tokenizer is None:
if args.v2:
tokenizer = CLIPTokenizer.from_pretrained(original_path, subfolder="tokenizer")
else:
tokenizer = CLIPTokenizer.from_pretrained(original_path)
if hasattr(args, "max_token_length") and args.max_token_length is not None:
logger.info(f"update token length: {args.max_token_length}")
if args.tokenizer_cache_dir and not os.path.exists(local_tokenizer_path):
logger.info(f"save Tokenizer to cache: {local_tokenizer_path}")
tokenizer.save_pretrained(local_tokenizer_path)
return tokenizer
def prepare_accelerator(args: argparse.Namespace):
"""
this function also prepares deepspeed plugin
@@ -5550,6 +5534,7 @@ def sample_images_common(
):
"""
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
TODO Use strategies here
"""
if steps == 0: