Text Encoder cache (WIP)

Kohya S
2024-11-27 12:57:04 +09:00
parent bdac55ebbc
commit 3677094256
15 changed files with 628 additions and 471 deletions

View File

@@ -151,15 +151,20 @@ def train(args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
)
)
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
False,
)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
train_dataset_group.set_current_strategies()
@@ -236,7 +241,12 @@ def train(args):
t5xxl.to(accelerator.device)
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
False,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
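
A minimal sketch of the new positional argument order used by the two call sites above, one argument per parameter of the updated FluxTextEncoderOutputsCachingStrategy.__init__ shown later in this diff (the literal values are placeholders, not the training defaults):

    from library import strategy_base, strategy_flux

    strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
        True,   # cache_to_disk                   (args.cache_text_encoder_outputs_to_disk)
        1,      # batch_size                      (args.text_encoder_batch_size)
        False,  # skip_disk_cache_validity_check  (args.skip_cache_check)
        512,    # max_token_length                (t5xxl_max_token_length: 256 for schnell, 512 for dev)
        True,   # masked                          (args.apply_t5_attn_mask)
        False,  # is_partial
    )
    strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(strategy)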

View File

@@ -10,8 +10,6 @@ from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
import train_network
from library.utils import setup_logging
setup_logging()
@@ -19,6 +17,9 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
import train_network
class FluxNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
@@ -174,13 +175,17 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
# if the text encoder is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
t5xxl_max_token_length,
args.apply_t5_attn_mask,
is_partial=self.train_clip_l or self.train_t5xxl,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None

View File

@@ -10,10 +10,6 @@ import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
# TODO remove circular import by moving ImageInfo to a separate file
# from library.train_util import ImageInfo
from library import utils
from library.utils import setup_logging
setup_logging()
@@ -21,6 +17,8 @@ import logging
logger = logging.getLogger(__name__)
from library import utils
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
if dtype is None:
@@ -43,6 +41,58 @@ def get_available_dtypes() -> List[torch.dtype]:
return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
def remove_lower_precision_values(tensor_dict: Dict[str, torch.Tensor], keys_without_dtype: list[str]) -> None:
"""
Removes lower precision values from tensor_dict when a higher precision value exists for the same key.
"""
available_dtypes = get_available_dtypes()
available_dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dtype)}" for dtype in available_dtypes]
for key_without_dtype in keys_without_dtype:
available_itemsize = None
for dtype, dtype_suffix in zip(available_dtypes, available_dtype_suffixes):
key = key_without_dtype + dtype_suffix
if key in tensor_dict:
if available_itemsize is None:
available_itemsize = dtype.itemsize
elif available_itemsize > dtype.itemsize:
# if higher precision latents are already cached, remove lower precision latents
del tensor_dict[key]
def get_compatible_dtype_keys(
dict_keys: set[str], keys_without_dtype: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str]]:
"""
Returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
If the key is not found, it returns None.
If a key in dict_keys doesn't have a dtype suffix, it is acceptable, because it is a long tensor.
:param dict_keys: set of keys in the dictionary
:param keys_without_dtype: list of keys without dtype suffix to check
:param dtype: dtype to check, or None for any dtype
:return: list of keys with the specified dtype or higher precision dtype. If the key is not found, it returns None for that key.
"""
compatible_dtypes = get_compatible_dtypes(dtype)
dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dt)}" for dt in compatible_dtypes]
available_keys = []
for key_without_dtype in keys_without_dtype:
available_key = None
if key_without_dtype in dict_keys:
available_key = key_without_dtype
else:
for dtype_suffix in dtype_suffixes:
key = key_without_dtype + dtype_suffix
if key in dict_keys:
available_key = key
break
available_keys.append(available_key)
return available_keys
class TokenizeStrategy:
_strategy = None # strategy instance: actual strategy class
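
The helpers added above, remove_lower_precision_values and get_compatible_dtype_keys, implement the dtype-suffixed key lookup: a request for a given dtype is satisfied by a cached tensor of the same or higher precision, and keys without a dtype suffix (integer tensors such as token ids or attention masks) always match. A minimal standalone sketch of that idea; the suffix strings and the precision ordering are assumptions here, the real ones come from utils.dtype_to_normalized_str and get_compatible_dtypes:

    from typing import Optional
    import torch

    SUFFIXES = {torch.float32: "_fp32", torch.bfloat16: "_bf16", torch.float16: "_fp16"}  # assumed strings
    PRECISION_ORDER = [torch.float32, torch.bfloat16, torch.float16]  # high to low, simplified

    def find_key(dict_keys: set, base_key: str, wanted: Optional[torch.dtype]) -> Optional[str]:
        if base_key in dict_keys:  # suffix-less key: integer tensor, always acceptable
            return base_key
        dtypes = PRECISION_ORDER if wanted is None else PRECISION_ORDER[: PRECISION_ORDER.index(wanted) + 1]
        for dt in dtypes:  # the requested dtype or any higher-precision one is acceptable
            key = base_key + SUFFIXES[dt]
            if key in dict_keys:
                return key
        return None

    cached = {"t5_out_0_fp32", "l_pooled_0_fp16", "t5_attn_mask_0"}
    print(find_key(cached, "t5_out_0", torch.float16))        # t5_out_0_fp32 (higher precision is fine)
    print(find_key(cached, "l_pooled_0", torch.bfloat16))     # None (only fp16 is cached)
    print(find_key(cached, "t5_attn_mask_0", torch.float16))  # t5_attn_mask_0 (suffix-less integer key)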
@@ -347,17 +397,26 @@ class TextEncoderOutputsCachingStrategy:
def __init__(
self,
architecture: str,
cache_to_disk: bool,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool = False,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
"""
max_token_length: maximum token length for the model. Whether the starting and ending tokens are counted depends on the model.
"""
self._architecture = architecture
self._cache_to_disk = cache_to_disk
self._batch_size = batch_size
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
self._max_token_length = max_token_length
self._masked = masked
self._is_partial = is_partial
self._is_weighted = is_weighted
self._is_weighted = is_weighted # enable weighting by `()` or `[]` in the prompt
@classmethod
def set_strategy(cls, strategy):
@@ -369,6 +428,18 @@ class TextEncoderOutputsCachingStrategy:
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
return cls._strategy
@property
def architecture(self):
return self._architecture
@property
def max_token_length(self):
return self._max_token_length
@property
def masked(self):
return self._masked
@property
def cache_to_disk(self):
return self._cache_to_disk
@@ -377,6 +448,11 @@ class TextEncoderOutputsCachingStrategy:
def batch_size(self):
return self._batch_size
@property
def cache_suffix(self):
suffix_masked = "_m" if self.masked else ""
return f"_{self.architecture.lower()}_{self.max_token_length}{suffix_masked}_te.safetensors"
@property
def is_partial(self):
return self._is_partial
@@ -385,24 +461,145 @@ class TextEncoderOutputsCachingStrategy:
def is_weighted(self):
return self._is_weighted
def get_outputs_npz_path(self, image_abs_path: str) -> str:
def get_cache_path(self, absolute_path: str) -> str:
return os.path.splitext(absolute_path)[0] + self.cache_suffix
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
raise NotImplementedError
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
def load_from_disk_for_keys(self, cache_path: str, caption_index: int, base_keys: list[str]) -> list[Optional[torch.Tensor]]:
"""
Get tensors for base_keys (keys without dtype suffix); if a key is not found, None is returned for it.
Tensors of any dtype are returned, because cache validation is done in advance.
"""
with safe_open(cache_path, framework="pt") as f:
metadata = f.metadata()
version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, version.split("."))
if major > 1: # or (major == 1 and minor > 0):
if not self.load_version_warning_printed:
self.load_version_warning_printed = True
logger.warning(
f"Existing text encoder outputs cache file has a higher version {version} for {cache_path}. This may cause issues."
)
dict_keys = f.keys()
results = []
compatible_keys = self.get_compatible_output_keys(dict_keys, caption_index, base_keys, None)
for key in compatible_keys:
results.append(f.get_tensor(key) if key is not None else None)
return results
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
raise NotImplementedError
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
raise NotImplementedError
def get_key_suffix(self, prompt_id: int, dtype: Optional[Union[str, torch.dtype]] = None) -> str:
"""
Returns the key suffix for the given prompt/caption index. A dtype suffix is appended only for floating-point tensors; integer tensors (e.g. token ids or attention masks) get no dtype suffix.
"""
key_suffix = f"_{prompt_id}"
if dtype is not None and dtype.is_floating_point: # float tensor only
key_suffix += "_" + utils.dtype_to_normalized_str(dtype)
return key_suffix
def get_compatible_output_keys(
self, dict_keys: set[str], caption_index: int, base_keys: list[str], dtype: Optional[Union[str, torch.dtype]]
) -> list[Optional[str]]:
"""
returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
"""
key_suffix = self.get_key_suffix(caption_index, None)
keys_without_dtype = [k + key_suffix for k in base_keys]
return get_compatible_dtype_keys(dict_keys, keys_without_dtype, dtype)
def _default_is_disk_cached_outputs_expected(
self,
cache_path: str,
captions: list[str],
base_keys: list[tuple[str, bool]],
preferred_dtype: Optional[Union[str, torch.dtype]],
):
if not self.cache_to_disk:
return False
if not os.path.exists(cache_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
with utils.MemoryEfficientSafeOpen(cache_path) as f:
keys = f.keys()
metadata = f.metadata()
# check captions in metadata
for i, caption in enumerate(captions):
if metadata.get(f"caption{i+1}") != caption:
return False
compatible_keys = self.get_compatible_output_keys(keys, i, base_keys, preferred_dtype)
if any(key is None for key in compatible_keys):
return False
except Exception as e:
logger.error(f"Error loading file: {cache_path}")
raise e
return True
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
self,
tokenize_strategy: TokenizeStrategy,
models: list[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
raise NotImplementedError
def save_outputs_to_disk(self, cache_path: str, caption_index: int, caption: str, keys: list[str], outputs: list[torch.Tensor]):
tensor_dict = {}
overwrite = False
if os.path.exists(cache_path):
# load existing safetensors and update it
overwrite = True
with utils.MemoryEfficientSafeOpen(cache_path) as f:
metadata = f.metadata()
existing_keys = f.keys()  # do not shadow the `keys` argument
for existing_key in existing_keys:
tensor_dict[existing_key] = f.get_tensor(existing_key)
assert metadata["architecture"] == self.architecture
file_version = metadata.get("format_version", "0.0.0")
major, minor, patch = map(int, file_version.split("."))
if major > 1 or (major == 1 and minor > 0):
self.save_version_warning_printed = True
logger.warning(
f"Existing text encoder outputs cache file has a higher version {file_version} for {cache_path}. This may cause issues."
)
else:
metadata = {}
metadata["architecture"] = self.architecture
metadata["format_version"] = "1.0.0"
metadata[f"caption{caption_index+1}"] = caption
for key, output in zip(keys, outputs):
dtype = output.dtype # long or one of float
key_suffix = self.get_key_suffix(caption_index, dtype)
tensor_dict[key + key_suffix] = output
# remove lower precision latents if higher precision latents are already cached
if overwrite:
suffix_without_dtype = self.get_key_suffix(caption_index, None)
remove_lower_precision_values(tensor_dict, [key + suffix_without_dtype])
save_file(tensor_dict, cache_path, metadata=metadata)
class LatentsCachingStrategy:
# TODO commonize utility functions to this class, such as npz handling etc.
_strategy = None # strategy instance: actual strategy class
def __init__(
@@ -495,36 +692,22 @@ class LatentsCachingStrategy:
def get_compatible_latents_keys(
self,
keys: set[str],
dtype: Union[str, torch.dtype],
dtype: Optional[Union[str, torch.dtype]],
flip_aug: bool,
bucket_reso: Optional[Tuple[int, int]] = None,
latents_size: Optional[Tuple[int, int]] = None,
) -> Tuple[Optional[str], Optional[str]]:
) -> list[Optional[str]]:
"""
bucket_reso is (W, H), latents_size is (H, W)
"""
latents_key = None
flipped_latents_key = None
key_suffix = self.get_key_suffix(bucket_reso, latents_size, None)
keys_without_dtype = ["latents" + key_suffix]
if flip_aug:
keys_without_dtype.append("latents_flipped" + key_suffix)
compatible_dtypes = get_compatible_dtypes(dtype)
for compat_dtype in compatible_dtypes:
key_suffix = self.get_key_suffix(bucket_reso, latents_size, compat_dtype)
if latents_key is None:
latents_key = "latents" + key_suffix
if latents_key not in keys:
latents_key = None
if flip_aug and flipped_latents_key is None:
flipped_latents_key = "latents_flipped" + key_suffix
if flipped_latents_key not in keys:
flipped_latents_key = None
if latents_key is not None and (flipped_latents_key is not None or not flip_aug):
break
return latents_key, flipped_latents_key
compatible_keys = get_compatible_dtype_keys(keys, keys_without_dtype, dtype)
return compatible_keys if flip_aug else compatible_keys + [None]
def _default_is_disk_cached_latents_expected(
self,
@@ -555,24 +738,13 @@ class LatentsCachingStrategy:
# print(f"alpha_mask not found: {latents_cache_path}")
return False
if preferred_dtype is None:
# remove dtype suffix from keys, because any dtype is acceptable
keys = [key.rsplit("_", 1)[0] for key in keys if not key.endswith(key_suffix_without_dtype)]
keys = set(keys)
if "latents" + key_suffix_without_dtype not in keys:
# print(f"No preferred: latents {key_suffix_without_dtype} not found: {latents_cache_path}")
return False
if flip_aug and "latents_flipped" + key_suffix_without_dtype not in keys:
# print(f"No preferred: latents_flipped {key_suffix_without_dtype} not found: {latents_cache_path}")
return False
else:
# specific dtype or compatible dtype is required
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
)
if latents_key is None or (flip_aug and flipped_latents_key is None):
# print(f"Precise dtype not found: {latents_cache_path}")
return False
# preferred_dtype is None if any dtype is acceptable
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
)
if latents_key is None or (flip_aug and flipped_latents_key is None):
# print(f"Precise dtype not found: {latents_cache_path}")
return False
except Exception as e:
logger.error(f"Error loading file: {latents_cache_path}")
raise e
@@ -581,7 +753,14 @@ class LatentsCachingStrategy:
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
self,
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List[utils.ImageInfo],
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
@@ -711,22 +890,7 @@ class LatentsCachingStrategy:
# remove lower precision latents if higher precision latents are already cached
if overwrite:
available_dtypes = get_available_dtypes()
available_itemsize = None
available_itemsize_flipped = None
for dtype in available_dtypes:
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
if "latents" + key_suffix in tensor_dict:
if available_itemsize is None:
available_itemsize = dtype.itemsize
elif available_itemsize > dtype.itemsize:
# if higher precision latents are already cached, remove lower precision latents
del tensor_dict["latents" + key_suffix]
if "latents_flipped" + key_suffix in tensor_dict:
if available_itemsize_flipped is None:
available_itemsize_flipped = dtype.itemsize
elif available_itemsize_flipped > dtype.itemsize:
del tensor_dict["latents_flipped" + key_suffix]
suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
remove_lower_precision_values(tensor_dict, ["latents" + suffix_without_dtype, "latents_flipped" + suffix_without_dtype])
save_file(tensor_dict, cache_path, metadata=metadata)
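
Taken together, the text encoder outputs cache written by save_outputs_to_disk above is one safetensors file per image: the file name comes from cache_suffix (architecture, max_token_length, and an "_m" marker when masked), captions are stored in the metadata as caption1, caption2, ..., and each tensor key combines a base key, a caption index, and a dtype suffix for floating-point tensors. A small inspection sketch; the file name, the "flux" architecture string, and the "fp32" suffix are assumptions for illustration:

    from safetensors import safe_open

    # hypothetical path as produced by get_cache_path("/data/img.png") for a masked 512-token Flux cache
    cache_path = "/data/img_flux_512_m_te.safetensors"

    with safe_open(cache_path, framework="pt") as f:
        print(f.metadata())   # e.g. {"architecture": "flux", "format_version": "1.0.0", "caption1": "a cat", ...}
        for key in f.keys():  # e.g. "l_pooled_0_fp32", "t5_out_0_fp32", "txt_ids_0_fp32", "t5_attn_mask_0"
            print(key, f.get_tensor(key).shape)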

View File

@@ -5,9 +5,6 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast
from library import flux_utils, train_util
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
@@ -15,6 +12,8 @@ import logging
logger = logging.getLogger(__name__)
from library import flux_utils, train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
@@ -86,64 +85,56 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
KEYS = ["l_pooled", "t5_out", "txt_ids"]
KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
skip_disk_cache_validity_check: bool,
max_token_length: int,
masked: bool,
is_partial: bool = False,
apply_t5_attn_mask: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_t5_attn_mask = apply_t5_attn_mask
super().__init__(
FluxLatentsCachingStrategy.ARCHITECTURE,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked,
is_partial,
)
self.warn_fp8_weights = False
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
):
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys = keys + FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED  # avoid mutating the class-level KEYS list
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "l_pooled" not in npz:
return False
if "t5_out" not in npz:
return False
if "txt_ids" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
l_pooled = data["l_pooled"]
t5_out = data["t5_out"]
txt_ids = data["txt_ids"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys(
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS
)
if self.masked:
t5_attn_mask = self.load_from_disk_for_keys(
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
)[0]
else:
t5_attn_mask = None
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
if not self.warn_fp8_weights:
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
@@ -154,44 +145,38 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
self.warn_fp8_weights = True
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
captions = [caption for _, _, caption in batch]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
if l_pooled.dtype == torch.bfloat16:
l_pooled = l_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
if txt_ids.dtype == torch.bfloat16:
txt_ids = txt_ids.float()
l_pooled = l_pooled.cpu()
t5_out = t5_out.cpu()
txt_ids = txt_ids.cpu()
t5_attn_mask = tokens_and_masks[2].cpu()
l_pooled = l_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
txt_ids = txt_ids.cpu().numpy()
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys = keys + FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED  # avoid mutating the class-level KEYS list
for i, info in enumerate(infos):
for i, (info, caption_index, caption) in enumerate(batch):
l_pooled_i = l_pooled[i]
t5_out_i = t5_out[i]
txt_ids_i = txt_ids[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_t5_attn_mask_i = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
l_pooled=l_pooled_i,
t5_out=t5_out_i,
txt_ids=txt_ids_i,
t5_attn_mask=t5_attn_mask_i,
apply_t5_attn_mask=apply_t5_attn_mask_i,
)
outputs = [l_pooled_i, t5_out_i, txt_ids_i]
if self.masked:
outputs += [t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
@@ -215,8 +200,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype
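
A hedged usage sketch of the reworked Flux strategy, following the constructor parameters and the load_from_disk return order shown in this diff (the path and caption are placeholders):

    from library import strategy_flux

    strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
        cache_to_disk=True,
        batch_size=1,
        skip_disk_cache_validity_check=False,
        max_token_length=512,
        masked=True,     # apply the T5 attention mask
        is_partial=False,
    )

    cache_path = strategy.get_cache_path("/data/img.png")  # "/data/img" + strategy.cache_suffix
    if strategy.is_disk_cached_outputs_expected(cache_path, ["a cat"], None):
        l_pooled, t5_out, txt_ids, t5_attn_mask = strategy.load_from_disk(cache_path, caption_index=0)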

View File

@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
import torch
from transformers import CLIPTokenizer
from library import train_util
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
from library.utils import setup_logging
setup_logging()
@@ -13,6 +11,8 @@ import logging
logger = logging.getLogger(__name__)
from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
TOKENIZER_ID = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # only the tokenizer is used from this model; v2 and v2.1 share the same tokenizer spec
@@ -157,8 +157,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
vae_device = vae.device
vae_dtype = vae.dtype

View File

@@ -6,10 +6,6 @@ import torch
import numpy as np
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
from library import sd3_utils, train_util
from library import sd3_models
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
setup_logging()
@@ -17,6 +13,9 @@ import logging
logger = logging.getLogger(__name__)
from library import train_util, utils
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -254,7 +253,8 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
KEYS = ["lg_out", "t5_out", "lg_pooled"]
KEYS_MASKED = ["clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"]
def __init__(
self,
@@ -262,70 +262,51 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
batch_size: int,
skip_disk_cache_validity_check: bool,
is_partial: bool = False,
apply_lg_attn_mask: bool = False,
apply_t5_attn_mask: bool = False,
max_token_length: int = 256,
masked: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
self.apply_lg_attn_mask = apply_lg_attn_mask
self.apply_t5_attn_mask = apply_t5_attn_mask
"""
apply_lg_attn_mask and apply_t5_attn_mask must be the same
"""
super().__init__(
Sd3LatentsCachingStrategy.ARCHITECTURE_SD3,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
masked=masked,
is_partial=is_partial,
)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys = keys + Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED  # avoid mutating the class-level KEYS list
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "lg_out" not in npz:
return False
if "lg_pooled" not in npz:
return False
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
return False
if "apply_lg_attn_mask" not in npz:
return False
if "t5_out" not in npz:
return False
if "t5_attn_mask" not in npz:
return False
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
return False
if "apply_t5_attn_mask" not in npz:
return False
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
lg_out = data["lg_out"]
lg_pooled = data["lg_pooled"]
t5_out = data["t5_out"]
l_attn_mask = data["clip_l_attn_mask"]
g_attn_mask = data["clip_g_attn_mask"]
t5_attn_mask = data["t5_attn_mask"]
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
lg_out, lg_pooled, t5_out = self.load_from_disk_for_keys(
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS
)
if self.masked:
l_attn_mask, g_attn_mask, t5_attn_mask = self.load_from_disk_for_keys(
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
)
else:
l_attn_mask = g_attn_mask = t5_attn_mask = None
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
captions = [info.caption for info in infos]
captions = [caption for _, _, caption in batch]
tokens_and_masks = tokenize_strategy.tokenize(captions)
with torch.no_grad():
@@ -334,51 +315,47 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy,
models,
tokens_and_masks,
apply_lg_attn_mask=self.apply_lg_attn_mask,
apply_t5_attn_mask=self.apply_t5_attn_mask,
apply_lg_attn_mask=self.masked,
apply_t5_attn_mask=self.masked,
enable_dropout=False,
)
if lg_out.dtype == torch.bfloat16:
lg_out = lg_out.float()
if lg_pooled.dtype == torch.bfloat16:
lg_pooled = lg_pooled.float()
if t5_out.dtype == torch.bfloat16:
t5_out = t5_out.float()
lg_out = lg_out.cpu()
lg_pooled = lg_pooled.cpu()
t5_out = t5_out.cpu()
lg_out = lg_out.cpu().numpy()
lg_pooled = lg_pooled.cpu().numpy()
t5_out = t5_out.cpu().numpy()
l_attn_mask = tokens_and_masks[3].cpu()
g_attn_mask = tokens_and_masks[4].cpu()
t5_attn_mask = tokens_and_masks[5].cpu()
l_attn_mask = tokens_and_masks[3].cpu().numpy()
g_attn_mask = tokens_and_masks[4].cpu().numpy()
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
for i, info in enumerate(infos):
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
if self.masked:
keys = keys + Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED  # avoid mutating the class-level KEYS list
for i, (info, caption_index, caption) in enumerate(batch):
lg_out_i = lg_out[i]
t5_out_i = t5_out[i]
lg_pooled_i = lg_pooled[i]
l_attn_mask_i = l_attn_mask[i]
g_attn_mask_i = g_attn_mask[i]
t5_attn_mask_i = t5_attn_mask[i]
apply_lg_attn_mask = self.apply_lg_attn_mask
apply_t5_attn_mask = self.apply_t5_attn_mask
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
lg_out=lg_out_i,
lg_pooled=lg_pooled_i,
t5_out=t5_out_i,
clip_l_attn_mask=l_attn_mask_i,
clip_g_attn_mask=g_attn_mask_i,
t5_attn_mask=t5_attn_mask_i,
apply_lg_attn_mask=apply_lg_attn_mask,
apply_t5_attn_mask=apply_t5_attn_mask,
)
outputs = [lg_out_i, t5_out_i, lg_pooled_i]
if self.masked:
outputs += [l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i]
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
else:
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [
lg_out_i,
t5_out_i,
lg_pooled_i,
l_attn_mask_i,
g_attn_mask_i,
t5_attn_mask_i,
]
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
@@ -402,8 +379,7 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
return self._default_load_latents_from_disk(cache_path, bucket_reso)
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
vae_device = vae.device
vae_dtype = vae.dtype

View File

@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library.utils import setup_logging
@@ -14,6 +12,8 @@ import logging
logger = logging.getLogger(__name__)
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
from library import utils
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
@@ -21,6 +21,9 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
class SdxlTokenizeStrategy(TokenizeStrategy):
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
"""
max_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
"""
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
@@ -220,51 +223,51 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
ARCHITECTURE_SDXL = "sdxl"
KEYS = ["hidden_state1", "hidden_state2", "pool2"]
def __init__(
self,
cache_to_disk: bool,
batch_size: int,
batch_size: Optional[int],
skip_disk_cache_validity_check: bool,
max_token_length: Optional[int] = None,
is_partial: bool = False,
is_weighted: bool = False,
) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
"""
max_token_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
"""
max_token_length = max_token_length or 75
super().__init__(
SdxlTextEncoderOutputsCachingStrategy.ARCHITECTURE_SDXL,
cache_to_disk,
batch_size,
skip_disk_cache_validity_check,
max_token_length,
is_partial=is_partial,
is_weighted=is_weighted,
)
def get_outputs_npz_path(self, image_abs_path: str) -> str:
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
def is_disk_cached_outputs_expected(
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
) -> bool:
# SDXL does not support attn mask
base_keys = SdxlTextEncoderOutputsCachingStrategy.KEYS
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, base_keys, preferred_dtype)
def is_disk_cached_outputs_expected(self, npz_path: str):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
try:
npz = np.load(npz_path)
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
data = np.load(npz_path)
hidden_state1 = data["hidden_state1"]
hidden_state2 = data["hidden_state2"]
pool2 = data["pool2"]
return [hidden_state1, hidden_state2, pool2]
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
return self.load_from_disk_for_keys(cache_path, caption_index, SdxlTextEncoderOutputsCachingStrategy.KEYS)
def cache_batch_outputs(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
text_encoding_strategy: TextEncodingStrategy,
batch: list[tuple[utils.ImageInfo, int, str]],
):
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
captions = [info.caption for info in infos]
captions = [caption for _, _, caption in batch]
if self.is_weighted:
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
@@ -279,28 +282,24 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
tokenize_strategy, models, [tokens1, tokens2]
)
if hidden_state1.dtype == torch.bfloat16:
hidden_state1 = hidden_state1.float()
if hidden_state2.dtype == torch.bfloat16:
hidden_state2 = hidden_state2.float()
if pool2.dtype == torch.bfloat16:
pool2 = pool2.float()
hidden_state1 = hidden_state1.cpu()
hidden_state2 = hidden_state2.cpu()
pool2 = pool2.cpu()
hidden_state1 = hidden_state1.cpu().numpy()
hidden_state2 = hidden_state2.cpu().numpy()
pool2 = pool2.cpu().numpy()
for i, info in enumerate(infos):
for i, (info, caption_index, caption) in enumerate(batch):
hidden_state1_i = hidden_state1[i]
hidden_state2_i = hidden_state2[i]
pool2_i = pool2[i]
if self.cache_to_disk:
np.savez(
info.text_encoder_outputs_npz,
hidden_state1=hidden_state1_i,
hidden_state2=hidden_state2_i,
pool2=pool2_i,
self.save_outputs_to_disk(
info.text_encoder_outputs_cache_path,
caption_index,
caption,
SdxlTextEncoderOutputsCachingStrategy.KEYS,
[hidden_state1_i, hidden_state2_i, pool2_i],
)
else:
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
while len(info.text_encoder_outputs) <= caption_index:
info.text_encoder_outputs.append(None)
info.text_encoder_outputs[caption_index] = [hidden_state1_i, hidden_state2_i, pool2_i]
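
The SDXL strategy now also carries max_token_length (None falls back to 75) ahead of is_partial and is_weighted; a hedged sketch of the call pattern used by the SDXL training scripts later in this diff, with placeholder values:

    from library import strategy_base, strategy_sdxl

    strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
        True,   # cache_to_disk
        None,   # batch_size
        False,  # skip_disk_cache_validity_check
        225,    # max_token_length: None, 150 or 225 (None falls back to 75)
        is_weighted=False,
    )
    strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(strategy)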

View File

@@ -83,7 +83,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, pil_resize, ImageInfo
setup_logging()
import logging
@@ -146,36 +146,6 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.caption: str = caption
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
self.latents: Optional[torch.Tensor] = None
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_cache_path: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
# crop left top right bottom in original pixel size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
# old
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
if max_size is not None:
@@ -751,116 +721,111 @@ class BaseDataset(torch.utils.data.Dataset):
def add_replacement(self, str_from, str_to):
self.replacements[str_from] = str_to
def process_caption(self, subset: BaseSubset, caption):
# add prefix/suffix to the caption
if subset.caption_prefix:
caption = subset.caption_prefix + " " + caption
if subset.caption_suffix:
caption = caption + " " + subset.caption_suffix
# decide caption dropout: tag dropout is also handled in this method, so it is done here as well
def process_caption(self, subset: BaseSubset, caption: str, tags: Optional[str]) -> str:
# drop out caption
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
is_drop_out = (
is_drop_out
or subset.caption_dropout_every_n_epochs > 0
and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
)
if is_drop_out:
caption = ""
else:
# process wildcards
if subset.enable_wildcard:
# if caption is multiline, random choice one line
if "\n" in caption:
caption = random.choice(caption.split("\n"))
return ""
# wildcard is like '{aaa|bbb|ccc...}'
# escape the curly braces like {{ or }}
replacer1 = ""
replacer2 = ""
while replacer1 in caption or replacer2 in caption:
replacer1 += ""
replacer2 += ""
# add prefix and suffix for caption
# DreamBooth: treated as tags, FineTuning: treated as caption, tags are processed separately
if subset.caption_prefix:
caption = subset.caption_prefix + " " + caption
if subset.caption_suffix:
caption = caption + " " + subset.caption_suffix
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
# shuffle tags
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
if tags is None and caption is not None: # DreamBooth method
tags = caption
caption = ""
# replace the wildcard
def replace_wildcard(match):
return random.choice(match.group(1).split("|"))
fixed_tokens = []
flex_tokens = []
fixed_suffix_tokens = []
if hasattr(subset, "keep_tokens_separator") and subset.keep_tokens_separator and subset.keep_tokens_separator in tags:
fixed_part, flex_part = tags.split(subset.keep_tokens_separator, 1)
if subset.keep_tokens_separator in flex_part:
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
# unescape the curly braces
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
# if caption is multiline, use the first line
caption = caption.split("\n")[0]
tokens = [t.strip() for t in tags.strip().split(subset.caption_separator)]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
fixed_tokens = []
flex_tokens = []
fixed_suffix_tokens = []
if (
hasattr(subset, "keep_tokens_separator")
and subset.keep_tokens_separator
and subset.keep_tokens_separator in caption
):
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
if subset.keep_tokens_separator in flex_part:
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
if subset.token_warmup_step < 1: # overwritten on the first call
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
flex_tokens = flex_tokens[:tokens_len]
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
return tokens
l = []
for token in tokens:
if random.random() >= subset.caption_tag_dropout_rate:
l.append(token)
return l
if subset.shuffle_caption:
random.shuffle(flex_tokens)
flex_tokens = dropout_tags(flex_tokens)
tags = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
if tags is not None:
caption = caption + " " + tags
# process wildcards
if subset.enable_wildcard:
# wildcard is like '{aaa|bbb|ccc...}'
# escape the curly braces like {{ or }}
replacer1 = ""
replacer2 = ""
while replacer1 in caption or replacer2 in caption:
replacer1 += ""
replacer2 += ""
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
# replace the wildcard
def replace_wildcard(match):
return random.choice(match.group(1).split("|"))
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
# unescape the curly braces
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
# process secondary separator
if subset.secondary_separator:
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
# support textual inversion
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
else:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
if subset.token_warmup_step < 1: # overwritten on the first call
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor(
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
)
+ subset.token_warmup_min
)
flex_tokens = flex_tokens[:tokens_len]
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
return tokens
l = []
for token in tokens:
if random.random() >= subset.caption_tag_dropout_rate:
l.append(token)
return l
if subset.shuffle_caption:
random.shuffle(flex_tokens)
flex_tokens = dropout_tags(flex_tokens)
caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
# process secondary separator
if subset.secondary_separator:
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
# support textual inversion
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
caption = str_to
else:
caption = caption.replace(str_from, str_to)
return caption
@@ -1171,24 +1136,28 @@ class BaseDataset(torch.utils.data.Dataset):
for i, info in enumerate(tqdm(image_infos)):
# check disk cache exists and size of text encoder outputs
if caching_strategy.cache_to_disk:
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
cache_path = caching_strategy.get_cache_path(info.absolute_path)
info.text_encoder_outputs_cache_path = cache_path # set the cache filename regardless of cache availability
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different text encoder outputs
if i % num_processes != process_index:
continue
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
cache_available = caching_strategy.is_disk_cached_outputs_expected(cache_path)
if cache_available: # do not add to batch
continue
batch.append(info)
for j, caption in enumerate(info.captions):
# using tags when caching text encoder outputs is not recommended
if info.list_of_tags is not None and len(info.list_of_tags) > 0:
caption = caption + " " + info.list_of_tags[j % len(info.list_of_tags)]
batch.append((info, j, caption))
# if number of data in batch is enough, flush the batch
if len(batch) >= batch_size:
batches.append(batch)
batch = []
while len(batch) >= batch_size:
batches.append(batch[:batch_size])
batch = batch[batch_size:]
if len(batch) > 0:
batches.append(batch)
@@ -1413,54 +1382,43 @@ class BaseDataset(torch.utils.data.Dataset):
flippeds.append(flipped)
# process the caption and text encoder outputs
caption = image_info.caption # default
tokenization_required = (
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
)
text_encoder_outputs = None
input_ids = None
caption = ""
if image_info.text_encoder_outputs is not None:
# cached
# cached on memory
text_encoder_outputs = image_info.text_encoder_outputs
elif image_info.text_encoder_outputs_npz is not None:
if len(text_encoder_outputs) == 1:
text_encoder_outputs = text_encoder_outputs[0]
else:
text_encoder_outputs = random.choices(text_encoder_outputs, weights=image_info.caption_weights)[0]
elif image_info.text_encoder_outputs_cache_path is not None:
# on disk
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
image_info.text_encoder_outputs_npz
index = 0
if len(image_info.captions) > 1:
index = random.choices(range(len(image_info.captions)), weights=image_info.caption_weights)[0]
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_from_disk(
image_info.text_encoder_outputs_cache_path, index
)
else:
tokenization_required = True
text_encoder_outputs_list.append(text_encoder_outputs)
if tokenization_required:
caption = self.process_caption(subset, image_info.caption)
caption = ""
tags = None # None if no tags in dataset metadata or Dreambooth method is used
if image_info.captions is not None and len(image_info.captions) > 0:
# caption_weights may be None
caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0]
if image_info.list_of_tags is not None and len(image_info.list_of_tags) > 0:
tags = random.choices(image_info.list_of_tags, weights=image_info.tags_weights)[0]
caption = self.process_caption(subset, caption, tags)
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
# if self.XTI_layers:
# caption_layer = []
# for layer in self.XTI_layers:
# token_strings_from = " ".join(self.token_strings)
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
# caption_ = caption.replace(token_strings_from, token_strings_to)
# caption_layer.append(caption_)
# captions.append(caption_layer)
# else:
# captions.append(caption)
# if not self.token_padding_disabled: # this option might be omitted in future
# # TODO get_input_ids must support SD3
# if self.XTI_layers:
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
# else:
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
# input_ids_list.append(token_caption)
# if len(self.tokenizers) > 1:
# if self.XTI_layers:
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
# else:
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
# input_ids2_list.append(token_caption2)
input_ids_list.append(input_ids)
captions.append(caption)
@@ -1798,7 +1756,8 @@ class DreamBoothDataset(BaseDataset):
num_train_images += subset.num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
captions = caption.split("\n") # empty line is allowed
info = ImageInfo(img_path, subset.num_repeats, captions, subset.is_reg, img_path)
if size is not None:
info.image_size = size
if subset.is_reg:
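
With these changes a DreamBooth caption file may contain several captions, one per line, and text encoder outputs are cached per (image, caption index) pair; a small standalone sketch of the batching behaviour above, with made-up file names:

    # standalone sketch, not the dataset class itself
    batch_size = 4
    batch, batches = [], []

    caption_files = {
        "img1.png": "a cat\na small cat sitting on a sofa",  # two captions, one per line
        "img2.png": "a dog",
    }
    for path, raw in caption_files.items():
        for j, caption in enumerate(raw.split("\n")):  # empty lines are allowed
            batch.append((path, j, caption))           # (ImageInfo, caption_index, caption) in the real code
        while len(batch) >= batch_size:                # flush full batches as soon as possible
            batches.append(batch[:batch_size])
            batch = batch[batch_size:]
    if batch:
        batches.append(batch)

    print(batches)  # one batch with three (image, caption_index, caption) entries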

View File

@@ -21,6 +21,41 @@ def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
class ImageInfo:
def __init__(
self, image_key: str, num_repeats: int, captions: Optional[Union[str, list[str]]], is_reg: bool, absolute_path: str
) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.captions: Optional[list[str]] = None if captions is None else ([captions] if isinstance(captions, str) else captions)
self.caption_weights: Optional[list[float]] = None # weights for each caption in sampling
self.list_of_tags: Optional[list[str]] = None
self.tags_weights: Optional[list[float]] = None
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
self.latents: Optional[torch.Tensor] = None
self.latents_flipped: Optional[torch.Tensor] = None
self.latents_cache_path: Optional[str] = None # set in cache_latents
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
# crop left top right bottom in original pixel size, not latents size
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
self.cond_img_path: Optional[str] = None
self.image: Optional[Image.Image] = None # optional, original PIL Image. None if the latents are not cached
self.text_encoder_outputs_cache_path: Optional[str] = None # set in cache_text_encoder_outputs
# new
self.text_encoder_outputs: Optional[list[list[torch.Tensor]]] = None
# old
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
# region Logging
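
ImageInfo now lives in library.utils and accepts either a single caption string or a list of captions, which is normalized to a list internally; a short construction sketch based on the signature above (paths are placeholders):

    from library.utils import ImageInfo

    info = ImageInfo("img1", 1, "a cat", False, "/data/img1.png")
    print(info.captions)         # ["a cat"]

    info = ImageInfo("img2", 1, ["a dog", "a small dog"], False, "/data/img2.png")
    print(info.captions)         # ["a dog", "a small dog"]
    print(info.caption_weights)  # None until sampling weights are assigned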

View File

@@ -75,6 +75,12 @@ def train(args):
)
args.cache_text_encoder_outputs = True
if args.cache_text_encoder_outputs:
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
)
assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
"when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
+ " / text encoderの学習時はtext encoderの出力はキャッシュできませんt5xxlのみキャッシュすることは可能です"
@@ -168,8 +174,8 @@ def train(args):
args.text_encoder_batch_size,
False,
False,
False,
False,
args.t5xxl_max_token_length,
args.apply_lg_attn_mask,
)
)
train_dataset_group.set_current_strategies()
@@ -278,8 +284,8 @@ def train(args):
args.text_encoder_batch_size,
args.skip_cache_check,
train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial
args.t5xxl_max_token_length,
args.apply_lg_attn_mask,
args.apply_t5_attn_mask,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)

View File

@@ -43,6 +43,10 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
assert (
train_dataset_group.is_text_encoder_output_cacheable()
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
)
# prepare CLIP-L/CLIP-G/T5XXL training flags
self.train_clip = not args.network_train_unet_only
@@ -183,8 +187,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_clip or self.train_t5xxl,
max_token_length=args.t5xxl_max_token_length,
apply_lg_attn_mask=args.apply_lg_attn_mask,
apply_t5_attn_mask=args.apply_t5_attn_mask,
)
else:
return None

View File

@@ -321,7 +321,11 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)

View File

@@ -223,7 +223,11 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)

View File

@@ -195,7 +195,11 @@ def train(args):
if args.cache_text_encoder_outputs:
# Text Encodes are eval and no grad
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, False
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)

View File

@@ -81,7 +81,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
args.cache_text_encoder_outputs_to_disk,
None,
args.skip_cache_check,
args.max_token_length,
is_weighted=args.weighted_captions,
)
else:
return None