mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Text Encoder cache (WIP)
This commit is contained in:
@@ -151,15 +151,20 @@ def train(args):
|
||||
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, args.skip_cache_check, False
|
||||
)
|
||||
)
|
||||
t5xxl_max_token_length = (
|
||||
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
|
||||
)
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
t5xxl_max_token_length,
|
||||
args.apply_t5_attn_mask,
|
||||
False,
|
||||
)
|
||||
)
|
||||
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
|
||||
|
||||
train_dataset_group.set_current_strategies()
|
||||
@@ -236,7 +241,12 @@ def train(args):
|
||||
t5xxl.to(accelerator.device)
|
||||
|
||||
text_encoder_caching_strategy = strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False, args.apply_t5_attn_mask
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
t5xxl_max_token_length,
|
||||
args.apply_t5_attn_mask,
|
||||
False,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
|
||||
@@ -10,8 +10,6 @@ from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
||||
import train_network
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -19,6 +17,9 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
||||
import train_network
|
||||
|
||||
|
||||
class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
@@ -174,13 +175,17 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
|
||||
|
||||
# if the text encoders is trained, we need tokenization, so is_partial is True
|
||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
t5xxl_max_token_length,
|
||||
args.apply_t5_attn_mask,
|
||||
is_partial=self.train_clip_l or self.train_t5xxl,
|
||||
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -10,10 +10,6 @@ import torch
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
|
||||
# TODO remove circular import by moving ImageInfo to a separate file
|
||||
# from library.train_util import ImageInfo
|
||||
|
||||
from library import utils
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -21,6 +17,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import utils
|
||||
|
||||
|
||||
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
|
||||
if dtype is None:
|
||||
@@ -43,6 +41,58 @@ def get_available_dtypes() -> List[torch.dtype]:
|
||||
return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
|
||||
|
||||
|
||||
def remove_lower_precision_values(tensor_dict: Dict[str, torch.Tensor], keys_without_dtype: list[str]) -> None:
|
||||
"""
|
||||
Removes lower precision values from tensor_dict.
|
||||
"""
|
||||
available_dtypes = get_available_dtypes()
|
||||
available_dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dtype)}" for dtype in available_dtypes]
|
||||
|
||||
for key_without_dtype in keys_without_dtype:
|
||||
available_itemsize = None
|
||||
for dtype, dtype_suffix in zip(available_dtypes, available_dtype_suffixes):
|
||||
key = key_without_dtype + dtype_suffix
|
||||
|
||||
if key in tensor_dict:
|
||||
if available_itemsize is None:
|
||||
available_itemsize = dtype.itemsize
|
||||
elif available_itemsize > dtype.itemsize:
|
||||
# if higher precision latents are already cached, remove lower precision latents
|
||||
del tensor_dict[key]
|
||||
|
||||
|
||||
def get_compatible_dtype_keys(
|
||||
dict_keys: set[str], keys_without_dtype: list[str], dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> list[Optional[str]]:
|
||||
"""
|
||||
Returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
|
||||
If the key is not found, it returns None.
|
||||
If the key in dict_keys doesn't have dtype suffix, it is acceptable, because it it long tensor.
|
||||
|
||||
:param dict_keys: set of keys in the dictionary
|
||||
:param keys_without_dtype: list of keys without dtype suffix to check
|
||||
:param dtype: dtype to check, or None for any dtype
|
||||
:return: list of keys with the specified dtype or higher precision dtype. If the key is not found, it returns None for that key.
|
||||
"""
|
||||
compatible_dtypes = get_compatible_dtypes(dtype)
|
||||
dtype_suffixes = [f"_{utils.dtype_to_normalized_str(dt)}" for dt in compatible_dtypes]
|
||||
|
||||
available_keys = []
|
||||
for key_without_dtype in keys_without_dtype:
|
||||
available_key = None
|
||||
if key_without_dtype in dict_keys:
|
||||
available_key = key_without_dtype
|
||||
else:
|
||||
for dtype_suffix in dtype_suffixes:
|
||||
key = key_without_dtype + dtype_suffix
|
||||
if key in dict_keys:
|
||||
available_key = key
|
||||
break
|
||||
available_keys.append(available_key)
|
||||
|
||||
return available_keys
|
||||
|
||||
|
||||
class TokenizeStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
@@ -347,17 +397,26 @@ class TextEncoderOutputsCachingStrategy:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
architecture: str,
|
||||
cache_to_disk: bool,
|
||||
batch_size: Optional[int],
|
||||
skip_disk_cache_validity_check: bool,
|
||||
max_token_length: int,
|
||||
masked: bool = False,
|
||||
is_partial: bool = False,
|
||||
is_weighted: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
max_token_length: maximum token length for the model. Including/excluding starting and ending tokens depends on the model.
|
||||
"""
|
||||
self._architecture = architecture
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
self._max_token_length = max_token_length
|
||||
self._masked = masked
|
||||
self._is_partial = is_partial
|
||||
self._is_weighted = is_weighted
|
||||
self._is_weighted = is_weighted # enable weighting by `()` or `[]` in the prompt
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
@@ -369,6 +428,18 @@ class TextEncoderOutputsCachingStrategy:
|
||||
def get_strategy(cls) -> Optional["TextEncoderOutputsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def architecture(self):
|
||||
return self._architecture
|
||||
|
||||
@property
|
||||
def max_token_length(self):
|
||||
return self._max_token_length
|
||||
|
||||
@property
|
||||
def masked(self):
|
||||
return self._masked
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
@@ -377,6 +448,11 @@ class TextEncoderOutputsCachingStrategy:
|
||||
def batch_size(self):
|
||||
return self._batch_size
|
||||
|
||||
@property
|
||||
def cache_suffix(self):
|
||||
suffix_masked = "_m" if self.masked else ""
|
||||
return f"_{self.architecture.lower()}_{self.max_token_length}{suffix_masked}_te.safetensors"
|
||||
|
||||
@property
|
||||
def is_partial(self):
|
||||
return self._is_partial
|
||||
@@ -385,24 +461,145 @@ class TextEncoderOutputsCachingStrategy:
|
||||
def is_weighted(self):
|
||||
return self._is_weighted
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
def get_cache_path(self, absolute_path: str) -> str:
|
||||
return os.path.splitext(absolute_path)[0] + self.cache_suffix
|
||||
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
def load_from_disk_for_keys(self, cache_path: str, caption_index: int, base_keys: list[str]) -> list[Optional[torch.Tensor]]:
|
||||
"""
|
||||
get tensors for keys_without_dtype, without dtype suffix. if the key is not found, it returns None.
|
||||
all dtype tensors are returned, because cache validation is done in advance.
|
||||
"""
|
||||
with safe_open(cache_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
version = metadata.get("format_version", "0.0.0")
|
||||
major, minor, patch = map(int, version.split("."))
|
||||
if major > 1: # or (major == 1 and minor > 0):
|
||||
if not self.load_version_warning_printed:
|
||||
self.load_version_warning_printed = True
|
||||
logger.warning(
|
||||
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
|
||||
)
|
||||
|
||||
dict_keys = f.keys()
|
||||
results = []
|
||||
compatible_keys = self.get_compatible_output_keys(dict_keys, caption_index, base_keys, None)
|
||||
for key in compatible_keys:
|
||||
results.append(f.get_tensor(key) if key is not None else None)
|
||||
|
||||
return results
|
||||
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
|
||||
raise NotImplementedError
|
||||
def get_key_suffix(self, prompt_id: int, dtype: Optional[Union[str, torch.dtype]] = None) -> str:
|
||||
"""
|
||||
masked: may be False even if self.masked is True. It is False for some outputs.
|
||||
"""
|
||||
key_suffix = f"_{prompt_id}"
|
||||
if dtype is not None and dtype.is_floating_point: # float tensor only
|
||||
key_suffix += "_" + utils.dtype_to_normalized_str(dtype)
|
||||
return key_suffix
|
||||
|
||||
def get_compatible_output_keys(
|
||||
self, dict_keys: set[str], caption_index: int, base_keys: list[str], dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> list[Optional[str], Optional[str]]:
|
||||
"""
|
||||
returns the list of keys with the specified dtype or higher precision dtype. If the specified dtype is None, any dtype is acceptable.
|
||||
"""
|
||||
key_suffix = self.get_key_suffix(caption_index, None)
|
||||
keys_without_dtype = [k + key_suffix for k in base_keys]
|
||||
return get_compatible_dtype_keys(dict_keys, keys_without_dtype, dtype)
|
||||
|
||||
def _default_is_disk_cached_outputs_expected(
|
||||
self,
|
||||
cache_path: str,
|
||||
captions: list[str],
|
||||
base_keys: list[tuple[str, bool]],
|
||||
preferred_dtype: Optional[Union[str, torch.dtype]],
|
||||
):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(cache_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
with utils.MemoryEfficientSafeOpen(cache_path) as f:
|
||||
keys = f.keys()
|
||||
metadata = f.metadata()
|
||||
|
||||
# check captions in metadata
|
||||
for i, caption in enumerate(captions):
|
||||
if metadata.get(f"caption{i+1}") != caption:
|
||||
return False
|
||||
|
||||
compatible_keys = self.get_compatible_output_keys(keys, i, base_keys, preferred_dtype)
|
||||
if any(key is None for key in compatible_keys):
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {cache_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, batch: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: list[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def save_outputs_to_disk(self, cache_path: str, caption_index: int, caption: str, keys: list[str], outputs: list[torch.Tensor]):
|
||||
tensor_dict = {}
|
||||
|
||||
overwrite = False
|
||||
if os.path.exists(cache_path):
|
||||
# load existing safetensors and update it
|
||||
overwrite = True
|
||||
|
||||
with utils.MemoryEfficientSafeOpen(cache_path) as f:
|
||||
metadata = f.metadata()
|
||||
keys = f.keys()
|
||||
for key in keys:
|
||||
tensor_dict[key] = f.get_tensor(key)
|
||||
assert metadata["architecture"] == self.architecture
|
||||
|
||||
file_version = metadata.get("format_version", "0.0.0")
|
||||
major, minor, patch = map(int, file_version.split("."))
|
||||
if major > 1 or (major == 1 and minor > 0):
|
||||
self.save_version_warning_printed = True
|
||||
logger.warning(
|
||||
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
|
||||
)
|
||||
else:
|
||||
metadata = {}
|
||||
metadata["architecture"] = self.architecture
|
||||
metadata["format_version"] = "1.0.0"
|
||||
|
||||
metadata[f"caption{caption_index+1}"] = caption
|
||||
|
||||
for key, output in zip(keys, outputs):
|
||||
dtype = output.dtype # long or one of float
|
||||
key_suffix = self.get_key_suffix(caption_index, dtype)
|
||||
tensor_dict[key + key_suffix] = output
|
||||
|
||||
# remove lower precision latents if higher precision latents are already cached
|
||||
if overwrite:
|
||||
suffix_without_dtype = self.get_key_suffix(caption_index, None)
|
||||
remove_lower_precision_values(tensor_dict, [key + suffix_without_dtype])
|
||||
|
||||
save_file(tensor_dict, cache_path, metadata=metadata)
|
||||
|
||||
|
||||
class LatentsCachingStrategy:
|
||||
# TODO commonize utillity functions to this class, such as npz handling etc.
|
||||
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
def __init__(
|
||||
@@ -495,36 +692,22 @@ class LatentsCachingStrategy:
|
||||
def get_compatible_latents_keys(
|
||||
self,
|
||||
keys: set[str],
|
||||
dtype: Union[str, torch.dtype],
|
||||
dtype: Optional[Union[str, torch.dtype]],
|
||||
flip_aug: bool,
|
||||
bucket_reso: Optional[Tuple[int, int]] = None,
|
||||
latents_size: Optional[Tuple[int, int]] = None,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
) -> list[Optional[str], Optional[str]]:
|
||||
"""
|
||||
bucket_reso is (W, H), latents_size is (H, W)
|
||||
"""
|
||||
|
||||
latents_key = None
|
||||
flipped_latents_key = None
|
||||
key_suffix = self.get_key_suffix(bucket_reso, latents_size, None)
|
||||
keys_without_dtype = ["latents" + key_suffix]
|
||||
if flip_aug:
|
||||
keys_without_dtype.append("latents_flipped" + key_suffix)
|
||||
|
||||
compatible_dtypes = get_compatible_dtypes(dtype)
|
||||
|
||||
for compat_dtype in compatible_dtypes:
|
||||
key_suffix = self.get_key_suffix(bucket_reso, latents_size, compat_dtype)
|
||||
|
||||
if latents_key is None:
|
||||
latents_key = "latents" + key_suffix
|
||||
if latents_key not in keys:
|
||||
latents_key = None
|
||||
if flip_aug and flipped_latents_key is None:
|
||||
flipped_latents_key = "latents_flipped" + key_suffix
|
||||
if flipped_latents_key not in keys:
|
||||
flipped_latents_key = None
|
||||
|
||||
if latents_key is not None and (flipped_latents_key is not None or not flip_aug):
|
||||
break
|
||||
|
||||
return latents_key, flipped_latents_key
|
||||
compatible_keys = get_compatible_dtype_keys(keys, keys_without_dtype, dtype)
|
||||
return compatible_keys if flip_aug else compatible_keys[0] + [None]
|
||||
|
||||
def _default_is_disk_cached_latents_expected(
|
||||
self,
|
||||
@@ -555,24 +738,13 @@ class LatentsCachingStrategy:
|
||||
# print(f"alpha_mask not found: {latents_cache_path}")
|
||||
return False
|
||||
|
||||
if preferred_dtype is None:
|
||||
# remove dtype suffix from keys, because any dtype is acceptable
|
||||
keys = [key.rsplit("_", 1)[0] for key in keys if not key.endswith(key_suffix_without_dtype)]
|
||||
keys = set(keys)
|
||||
if "latents" + key_suffix_without_dtype not in keys:
|
||||
# print(f"No preferred: latents {key_suffix_without_dtype} not found: {latents_cache_path}")
|
||||
return False
|
||||
if flip_aug and "latents_flipped" + key_suffix_without_dtype not in keys:
|
||||
# print(f"No preferred: latents_flipped {key_suffix_without_dtype} not found: {latents_cache_path}")
|
||||
return False
|
||||
else:
|
||||
# specific dtype or compatible dtype is required
|
||||
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
|
||||
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
|
||||
)
|
||||
if latents_key is None or (flip_aug and flipped_latents_key is None):
|
||||
# print(f"Precise dtype not found: {latents_cache_path}")
|
||||
return False
|
||||
# preferred_dtype is None if any dtype is acceptable
|
||||
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
|
||||
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
|
||||
)
|
||||
if latents_key is None or (flip_aug and flipped_latents_key is None):
|
||||
# print(f"Precise dtype not found: {latents_cache_path}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {latents_cache_path}")
|
||||
raise e
|
||||
@@ -581,7 +753,14 @@ class LatentsCachingStrategy:
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def _default_cache_batch_latents(
|
||||
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
|
||||
self,
|
||||
encode_by_vae,
|
||||
vae_device,
|
||||
vae_dtype,
|
||||
image_infos: List[utils.ImageInfo],
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
):
|
||||
"""
|
||||
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
||||
@@ -711,22 +890,7 @@ class LatentsCachingStrategy:
|
||||
|
||||
# remove lower precision latents if higher precision latents are already cached
|
||||
if overwrite:
|
||||
available_dtypes = get_available_dtypes()
|
||||
available_itemsize = None
|
||||
available_itemsize_flipped = None
|
||||
for dtype in available_dtypes:
|
||||
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
|
||||
if "latents" + key_suffix in tensor_dict:
|
||||
if available_itemsize is None:
|
||||
available_itemsize = dtype.itemsize
|
||||
elif available_itemsize > dtype.itemsize:
|
||||
# if higher precision latents are already cached, remove lower precision latents
|
||||
del tensor_dict["latents" + key_suffix]
|
||||
|
||||
if "latents_flipped" + key_suffix in tensor_dict:
|
||||
if available_itemsize_flipped is None:
|
||||
available_itemsize_flipped = dtype.itemsize
|
||||
elif available_itemsize_flipped > dtype.itemsize:
|
||||
del tensor_dict["latents_flipped" + key_suffix]
|
||||
suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
|
||||
remove_lower_precision_values(tensor_dict, ["latents" + suffix_without_dtype, "latents_flipped" + suffix_without_dtype])
|
||||
|
||||
save_file(tensor_dict, cache_path, metadata=metadata)
|
||||
|
||||
@@ -5,9 +5,6 @@ import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from library import flux_utils, train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -15,6 +12,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import flux_utils, train_util, utils
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
T5_XXL_TOKENIZER_ID = "google/t5-v1_1-xxl"
|
||||
@@ -86,64 +85,56 @@ class FluxTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
|
||||
class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_flux_te.npz"
|
||||
KEYS = ["l_pooled", "t5_out", "txt_ids"]
|
||||
KEYS_MASKED = ["t5_attn_mask", "apply_t5_attn_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_to_disk: bool,
|
||||
batch_size: int,
|
||||
skip_disk_cache_validity_check: bool,
|
||||
max_token_length: int,
|
||||
masked: bool,
|
||||
is_partial: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
super().__init__(
|
||||
FluxLatentsCachingStrategy.ARCHITECTURE,
|
||||
cache_to_disk,
|
||||
batch_size,
|
||||
skip_disk_cache_validity_check,
|
||||
max_token_length,
|
||||
masked,
|
||||
is_partial,
|
||||
)
|
||||
|
||||
self.warn_fp8_weights = False
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + FluxTextEncoderOutputsCachingStrategy.FLUX_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
):
|
||||
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "l_pooled" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "txt_ids" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
l_pooled = data["l_pooled"]
|
||||
t5_out = data["t5_out"]
|
||||
txt_ids = data["txt_ids"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
# apply_t5_attn_mask should be same as self.apply_t5_attn_mask
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
l_pooled, t5_out, txt_ids = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
)
|
||||
if self.masked:
|
||||
t5_attn_mask = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
)[0]
|
||||
else:
|
||||
t5_attn_mask = None
|
||||
return [l_pooled, t5_out, txt_ids, t5_attn_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
if not self.warn_fp8_weights:
|
||||
if flux_utils.get_t5xxl_actual_dtype(models[1]) == torch.float8_e4m3fn:
|
||||
@@ -154,44 +145,38 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
self.warn_fp8_weights = True
|
||||
|
||||
flux_text_encoding_strategy: FluxTextEncodingStrategy = text_encoding_strategy
|
||||
captions = [info.caption for info in infos]
|
||||
captions = [caption for _, _, caption in batch]
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
# attn_mask is applied in text_encoding_strategy.encode_tokens if apply_t5_attn_mask is True
|
||||
l_pooled, t5_out, txt_ids, _ = flux_text_encoding_strategy.encode_tokens(tokenize_strategy, models, tokens_and_masks)
|
||||
|
||||
if l_pooled.dtype == torch.bfloat16:
|
||||
l_pooled = l_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
if txt_ids.dtype == torch.bfloat16:
|
||||
txt_ids = txt_ids.float()
|
||||
l_pooled = l_pooled.cpu()
|
||||
t5_out = t5_out.cpu()
|
||||
txt_ids = txt_ids.cpu()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu()
|
||||
|
||||
l_pooled = l_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
txt_ids = txt_ids.cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[2].cpu().numpy()
|
||||
keys = FluxTextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += FluxTextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
for i, (info, caption_index, caption) in enumerate(batch):
|
||||
l_pooled_i = l_pooled[i]
|
||||
t5_out_i = t5_out[i]
|
||||
txt_ids_i = txt_ids[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_t5_attn_mask_i = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
l_pooled=l_pooled_i,
|
||||
t5_out=t5_out_i,
|
||||
txt_ids=txt_ids_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask_i,
|
||||
)
|
||||
outputs = [l_pooled_i, t5_out_i, txt_ids_i]
|
||||
if self.masked:
|
||||
outputs += [t5_attn_mask_i]
|
||||
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
|
||||
else:
|
||||
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
||||
info.text_encoder_outputs = (l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i)
|
||||
while len(info.text_encoder_outputs) <= caption_index:
|
||||
info.text_encoder_outputs.append(None)
|
||||
info.text_encoder_outputs[caption_index] = [l_pooled_i, t5_out_i, txt_ids_i, t5_attn_mask_i]
|
||||
|
||||
|
||||
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
@@ -215,8 +200,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import CLIPTokenizer
|
||||
from library import train_util
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -13,6 +11,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import train_util, utils
|
||||
from library.strategy_base import LatentsCachingStrategy, TokenizeStrategy, TextEncodingStrategy
|
||||
|
||||
TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_ID = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う v2とv2.1はtokenizer仕様は同じ
|
||||
@@ -157,8 +157,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).latent_dist.sample()
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
@@ -6,10 +6,6 @@ import torch
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, CLIPTextModelWithProjection, T5EncoderModel
|
||||
|
||||
from library import sd3_utils, train_util
|
||||
from library import sd3_models
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -17,6 +13,9 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import train_util, utils
|
||||
from library.strategy_base import LatentsCachingStrategy, TextEncodingStrategy, TokenizeStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
|
||||
CLIP_L_TOKENIZER_ID = "openai/clip-vit-large-patch14"
|
||||
CLIP_G_TOKENIZER_ID = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
@@ -254,7 +253,8 @@ class Sd3TextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
|
||||
class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_sd3_te.npz"
|
||||
KEYS = ["lg_out", "t5_out", "lg_pooled"]
|
||||
KEYS_MASKED = ["clip_l_attn_mask", "clip_g_attn_mask", "t5_attn_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -262,70 +262,51 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
batch_size: int,
|
||||
skip_disk_cache_validity_check: bool,
|
||||
is_partial: bool = False,
|
||||
apply_lg_attn_mask: bool = False,
|
||||
apply_t5_attn_mask: bool = False,
|
||||
max_token_length: int = 256,
|
||||
masked: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial)
|
||||
self.apply_lg_attn_mask = apply_lg_attn_mask
|
||||
self.apply_t5_attn_mask = apply_t5_attn_mask
|
||||
"""
|
||||
apply_lg_attn_mask and apply_t5_attn_mask must be same
|
||||
"""
|
||||
super().__init__(
|
||||
Sd3LatentsCachingStrategy.ARCHITECTURE_SD3,
|
||||
cache_to_disk,
|
||||
batch_size,
|
||||
skip_disk_cache_validity_check,
|
||||
max_token_length,
|
||||
masked=masked,
|
||||
is_partial=is_partial,
|
||||
)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + Sd3TextEncoderOutputsCachingStrategy.SD3_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> bool:
|
||||
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, keys, preferred_dtype)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "lg_out" not in npz:
|
||||
return False
|
||||
if "lg_pooled" not in npz:
|
||||
return False
|
||||
if "clip_l_attn_mask" not in npz or "clip_g_attn_mask" not in npz: # necessary even if not used
|
||||
return False
|
||||
if "apply_lg_attn_mask" not in npz:
|
||||
return False
|
||||
if "t5_out" not in npz:
|
||||
return False
|
||||
if "t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_lg_attn_mask = npz["apply_lg_attn_mask"]
|
||||
if npz_apply_lg_attn_mask != self.apply_lg_attn_mask:
|
||||
return False
|
||||
if "apply_t5_attn_mask" not in npz:
|
||||
return False
|
||||
npz_apply_t5_attn_mask = npz["apply_t5_attn_mask"]
|
||||
if npz_apply_t5_attn_mask != self.apply_t5_attn_mask:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
lg_out = data["lg_out"]
|
||||
lg_pooled = data["lg_pooled"]
|
||||
t5_out = data["t5_out"]
|
||||
|
||||
l_attn_mask = data["clip_l_attn_mask"]
|
||||
g_attn_mask = data["clip_g_attn_mask"]
|
||||
t5_attn_mask = data["t5_attn_mask"]
|
||||
|
||||
# apply_t5_attn_mask and apply_lg_attn_mask are same as self.apply_t5_attn_mask and self.apply_lg_attn_mask
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
lg_out, lg_pooled, t5_out = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS
|
||||
)
|
||||
if self.masked:
|
||||
l_attn_mask, g_attn_mask, t5_attn_mask = self.load_from_disk_for_keys(
|
||||
cache_path, caption_index, Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
)
|
||||
else:
|
||||
l_attn_mask = g_attn_mask = t5_attn_mask = None
|
||||
return [lg_out, t5_out, lg_pooled, l_attn_mask, g_attn_mask, t5_attn_mask]
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
sd3_text_encoding_strategy: Sd3TextEncodingStrategy = text_encoding_strategy
|
||||
captions = [info.caption for info in infos]
|
||||
captions = [caption for _, _, caption in batch]
|
||||
|
||||
tokens_and_masks = tokenize_strategy.tokenize(captions)
|
||||
with torch.no_grad():
|
||||
@@ -334,51 +315,47 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
tokenize_strategy,
|
||||
models,
|
||||
tokens_and_masks,
|
||||
apply_lg_attn_mask=self.apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=self.apply_t5_attn_mask,
|
||||
apply_lg_attn_mask=self.masked,
|
||||
apply_t5_attn_mask=self.masked,
|
||||
enable_dropout=False,
|
||||
)
|
||||
|
||||
if lg_out.dtype == torch.bfloat16:
|
||||
lg_out = lg_out.float()
|
||||
if lg_pooled.dtype == torch.bfloat16:
|
||||
lg_pooled = lg_pooled.float()
|
||||
if t5_out.dtype == torch.bfloat16:
|
||||
t5_out = t5_out.float()
|
||||
lg_out = lg_out.cpu()
|
||||
lg_pooled = lg_pooled.cpu()
|
||||
t5_out = t5_out.cpu()
|
||||
|
||||
lg_out = lg_out.cpu().numpy()
|
||||
lg_pooled = lg_pooled.cpu().numpy()
|
||||
t5_out = t5_out.cpu().numpy()
|
||||
l_attn_mask = tokens_and_masks[3].cpu()
|
||||
g_attn_mask = tokens_and_masks[4].cpu()
|
||||
t5_attn_mask = tokens_and_masks[5].cpu()
|
||||
|
||||
l_attn_mask = tokens_and_masks[3].cpu().numpy()
|
||||
g_attn_mask = tokens_and_masks[4].cpu().numpy()
|
||||
t5_attn_mask = tokens_and_masks[5].cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
keys = Sd3TextEncoderOutputsCachingStrategy.KEYS
|
||||
if self.masked:
|
||||
keys += Sd3TextEncoderOutputsCachingStrategy.KEYS_MASKED
|
||||
for i, (info, caption_index, caption) in enumerate(batch):
|
||||
lg_out_i = lg_out[i]
|
||||
t5_out_i = t5_out[i]
|
||||
lg_pooled_i = lg_pooled[i]
|
||||
l_attn_mask_i = l_attn_mask[i]
|
||||
g_attn_mask_i = g_attn_mask[i]
|
||||
t5_attn_mask_i = t5_attn_mask[i]
|
||||
apply_lg_attn_mask = self.apply_lg_attn_mask
|
||||
apply_t5_attn_mask = self.apply_t5_attn_mask
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
lg_out=lg_out_i,
|
||||
lg_pooled=lg_pooled_i,
|
||||
t5_out=t5_out_i,
|
||||
clip_l_attn_mask=l_attn_mask_i,
|
||||
clip_g_attn_mask=g_attn_mask_i,
|
||||
t5_attn_mask=t5_attn_mask_i,
|
||||
apply_lg_attn_mask=apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=apply_t5_attn_mask,
|
||||
)
|
||||
outputs = [lg_out_i, t5_out_i, lg_pooled_i]
|
||||
if self.masked:
|
||||
outputs += [l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i]
|
||||
self.save_outputs_to_disk(info.text_encoder_outputs_cache_path, caption_index, caption, keys, outputs)
|
||||
else:
|
||||
# it's fine that attn mask is not None. it's overwritten before calling the model if necessary
|
||||
info.text_encoder_outputs = (lg_out_i, t5_out_i, lg_pooled_i, l_attn_mask_i, g_attn_mask_i, t5_attn_mask_i)
|
||||
while len(info.text_encoder_outputs) <= caption_index:
|
||||
info.text_encoder_outputs.append(None)
|
||||
info.text_encoder_outputs[caption_index] = [
|
||||
lg_out_i,
|
||||
t5_out_i,
|
||||
lg_pooled_i,
|
||||
l_attn_mask_i,
|
||||
g_attn_mask_i,
|
||||
t5_attn_mask_i,
|
||||
]
|
||||
|
||||
|
||||
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
@@ -402,8 +379,7 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
def cache_batch_latents(self, vae, image_infos: List[utils.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
encode_by_vae = lambda img_tensor: vae.encode(img_tensor).to("cpu")
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
@@ -4,8 +4,6 @@ from typing import Any, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
|
||||
|
||||
|
||||
from library.utils import setup_logging
|
||||
|
||||
@@ -14,6 +12,8 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library.strategy_base import TokenizeStrategy, TextEncodingStrategy, TextEncoderOutputsCachingStrategy
|
||||
from library import utils
|
||||
|
||||
TOKENIZER1_PATH = "openai/clip-vit-large-patch14"
|
||||
TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
@@ -21,6 +21,9 @@ TOKENIZER2_PATH = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
|
||||
|
||||
class SdxlTokenizeStrategy(TokenizeStrategy):
|
||||
def __init__(self, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None) -> None:
|
||||
"""
|
||||
max_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
|
||||
"""
|
||||
self.tokenizer1 = self._load_tokenizer(CLIPTokenizer, TOKENIZER1_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.tokenizer2 = self._load_tokenizer(CLIPTokenizer, TOKENIZER2_PATH, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
self.tokenizer2.pad_token_id = 0 # use 0 as pad token for tokenizer2
|
||||
@@ -220,51 +223,51 @@ class SdxlTextEncodingStrategy(TextEncodingStrategy):
|
||||
|
||||
|
||||
class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_te_outputs.npz"
|
||||
ARCHITECTURE_SDXL = "sdxl"
|
||||
KEYS = ["hidden_state1", "hidden_state2", "pool2"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_to_disk: bool,
|
||||
batch_size: int,
|
||||
batch_size: Optional[int],
|
||||
skip_disk_cache_validity_check: bool,
|
||||
max_token_length: Optional[int] = None,
|
||||
is_partial: bool = False,
|
||||
is_weighted: bool = False,
|
||||
) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check, is_partial, is_weighted)
|
||||
"""
|
||||
max_token_length: maximum length of the input text, **excluding** the special tokens. None or 150 or 225
|
||||
"""
|
||||
max_token_length = max_token_length or 75
|
||||
super().__init__(
|
||||
SdxlTextEncoderOutputsCachingStrategy.ARCHITECTURE_SDXL,
|
||||
cache_to_disk,
|
||||
batch_size,
|
||||
skip_disk_cache_validity_check,
|
||||
is_partial,
|
||||
is_weighted,
|
||||
max_token_length=max_token_length,
|
||||
)
|
||||
|
||||
def get_outputs_npz_path(self, image_abs_path: str) -> str:
|
||||
return os.path.splitext(image_abs_path)[0] + SdxlTextEncoderOutputsCachingStrategy.SDXL_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
|
||||
def is_disk_cached_outputs_expected(
|
||||
self, cache_path: str, prompts: list[str], preferred_dtype: Optional[Union[str, torch.dtype]]
|
||||
) -> bool:
|
||||
# SDXL does not support attn mask
|
||||
base_keys = SdxlTextEncoderOutputsCachingStrategy.KEYS
|
||||
return self._default_is_disk_cached_outputs_expected(cache_path, prompts, base_keys, preferred_dtype)
|
||||
|
||||
def is_disk_cached_outputs_expected(self, npz_path: str):
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "hidden_state1" not in npz or "hidden_state2" not in npz or "pool2" not in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
|
||||
data = np.load(npz_path)
|
||||
hidden_state1 = data["hidden_state1"]
|
||||
hidden_state2 = data["hidden_state2"]
|
||||
pool2 = data["pool2"]
|
||||
return [hidden_state1, hidden_state2, pool2]
|
||||
def load_from_disk(self, cache_path: str, caption_index: int) -> list[Optional[torch.Tensor]]:
|
||||
return self.load_from_disk_for_keys(cache_path, caption_index, SdxlTextEncoderOutputsCachingStrategy.KEYS)
|
||||
|
||||
def cache_batch_outputs(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], text_encoding_strategy: TextEncodingStrategy, infos: List
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
models: List[Any],
|
||||
text_encoding_strategy: TextEncodingStrategy,
|
||||
batch: list[tuple[utils.ImageInfo, int, str]],
|
||||
):
|
||||
sdxl_text_encoding_strategy = text_encoding_strategy # type: SdxlTextEncodingStrategy
|
||||
captions = [info.caption for info in infos]
|
||||
captions = [caption for _, _, caption in batch]
|
||||
|
||||
if self.is_weighted:
|
||||
tokens_list, weights_list = tokenize_strategy.tokenize_with_weights(captions)
|
||||
@@ -279,28 +282,24 @@ class SdxlTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
tokenize_strategy, models, [tokens1, tokens2]
|
||||
)
|
||||
|
||||
if hidden_state1.dtype == torch.bfloat16:
|
||||
hidden_state1 = hidden_state1.float()
|
||||
if hidden_state2.dtype == torch.bfloat16:
|
||||
hidden_state2 = hidden_state2.float()
|
||||
if pool2.dtype == torch.bfloat16:
|
||||
pool2 = pool2.float()
|
||||
hidden_state1 = hidden_state1.cpu()
|
||||
hidden_state2 = hidden_state2.cpu()
|
||||
pool2 = pool2.cpu()
|
||||
|
||||
hidden_state1 = hidden_state1.cpu().numpy()
|
||||
hidden_state2 = hidden_state2.cpu().numpy()
|
||||
pool2 = pool2.cpu().numpy()
|
||||
|
||||
for i, info in enumerate(infos):
|
||||
for i, (info, caption_index, caption) in enumerate(batch):
|
||||
hidden_state1_i = hidden_state1[i]
|
||||
hidden_state2_i = hidden_state2[i]
|
||||
pool2_i = pool2[i]
|
||||
|
||||
if self.cache_to_disk:
|
||||
np.savez(
|
||||
info.text_encoder_outputs_npz,
|
||||
hidden_state1=hidden_state1_i,
|
||||
hidden_state2=hidden_state2_i,
|
||||
pool2=pool2_i,
|
||||
self.save_outputs_to_disk(
|
||||
info.text_encoder_outputs_cache_path,
|
||||
caption_index,
|
||||
caption,
|
||||
SdxlTextEncoderOutputsCachingStrategy.KEYS,
|
||||
[hidden_state1_i, hidden_state2_i, pool2_i],
|
||||
)
|
||||
else:
|
||||
info.text_encoder_outputs = [hidden_state1_i, hidden_state2_i, pool2_i]
|
||||
while len(info.text_encoder_outputs) <= caption_index:
|
||||
info.text_encoder_outputs.append(None)
|
||||
info.text_encoder_outputs[caption_index] = [hidden_state1_i, hidden_state2_i, pool2_i]
|
||||
|
||||
@@ -83,7 +83,7 @@ import library.model_util as model_util
|
||||
import library.huggingface_util as huggingface_util
|
||||
import library.sai_model_spec as sai_model_spec
|
||||
import library.deepspeed_utils as deepspeed_utils
|
||||
from library.utils import setup_logging, pil_resize
|
||||
from library.utils import setup_logging, pil_resize, ImageInfo
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -146,36 +146,6 @@ TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
|
||||
self.image_key: str = image_key
|
||||
self.num_repeats: int = num_repeats
|
||||
self.caption: str = caption
|
||||
self.is_reg: bool = is_reg
|
||||
self.absolute_path: str = absolute_path
|
||||
self.image_size: Tuple[int, int] = None
|
||||
self.resized_size: Tuple[int, int] = None
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
self.latents: Optional[torch.Tensor] = None
|
||||
self.latents_flipped: Optional[torch.Tensor] = None
|
||||
self.latents_cache_path: Optional[str] = None # set in cache_latents
|
||||
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||
# crop left top right bottom in original pixel size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
|
||||
self.cond_img_path: Optional[str] = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
|
||||
|
||||
# new
|
||||
self.text_encoder_outputs: Optional[List[torch.Tensor]] = None
|
||||
# old
|
||||
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
|
||||
class BucketManager:
|
||||
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
|
||||
if max_size is not None:
|
||||
@@ -751,116 +721,111 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def add_replacement(self, str_from, str_to):
|
||||
self.replacements[str_from] = str_to
|
||||
|
||||
def process_caption(self, subset: BaseSubset, caption):
|
||||
# caption に prefix/suffix を付ける
|
||||
if subset.caption_prefix:
|
||||
caption = subset.caption_prefix + " " + caption
|
||||
if subset.caption_suffix:
|
||||
caption = caption + " " + subset.caption_suffix
|
||||
|
||||
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
||||
def process_caption(self, subset: BaseSubset, caption: str, tags: Optional[str]) -> str:
|
||||
# drop out caption
|
||||
is_drop_out = subset.caption_dropout_rate > 0 and random.random() < subset.caption_dropout_rate
|
||||
is_drop_out = (
|
||||
is_drop_out
|
||||
or subset.caption_dropout_every_n_epochs > 0
|
||||
and self.current_epoch % subset.caption_dropout_every_n_epochs == 0
|
||||
)
|
||||
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
# process wildcards
|
||||
if subset.enable_wildcard:
|
||||
# if caption is multiline, random choice one line
|
||||
if "\n" in caption:
|
||||
caption = random.choice(caption.split("\n"))
|
||||
return ""
|
||||
|
||||
# wildcard is like '{aaa|bbb|ccc...}'
|
||||
# escape the curly braces like {{ or }}
|
||||
replacer1 = "⦅"
|
||||
replacer2 = "⦆"
|
||||
while replacer1 in caption or replacer2 in caption:
|
||||
replacer1 += "⦅"
|
||||
replacer2 += "⦆"
|
||||
# add prefix and suffix for caption
|
||||
# DreamBooth: treated as tags, FineTuning: treated as caption, tags are processed separately
|
||||
if subset.caption_prefix:
|
||||
caption = subset.caption_prefix + " " + caption
|
||||
if subset.caption_suffix:
|
||||
caption = caption + " " + subset.caption_suffix
|
||||
|
||||
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
|
||||
# shuffle tags
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
if tags is None and caption is not None: # DreamBooth method
|
||||
tags = caption
|
||||
caption = ""
|
||||
|
||||
# replace the wildcard
|
||||
def replace_wildcard(match):
|
||||
return random.choice(match.group(1).split("|"))
|
||||
fixed_tokens = []
|
||||
flex_tokens = []
|
||||
fixed_suffix_tokens = []
|
||||
if hasattr(subset, "keep_tokens_separator") and subset.keep_tokens_separator and subset.keep_tokens_separator in tags:
|
||||
fixed_part, flex_part = tags.split(subset.keep_tokens_separator, 1)
|
||||
if subset.keep_tokens_separator in flex_part:
|
||||
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
|
||||
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
|
||||
|
||||
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
|
||||
|
||||
# unescape the curly braces
|
||||
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
|
||||
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||
else:
|
||||
# if caption is multiline, use the first line
|
||||
caption = caption.split("\n")[0]
|
||||
tokens = [t.strip() for t in tags.strip().split(subset.caption_separator)]
|
||||
flex_tokens = tokens[:]
|
||||
if subset.keep_tokens > 0:
|
||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||
flex_tokens = tokens[subset.keep_tokens :]
|
||||
|
||||
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
|
||||
fixed_tokens = []
|
||||
flex_tokens = []
|
||||
fixed_suffix_tokens = []
|
||||
if (
|
||||
hasattr(subset, "keep_tokens_separator")
|
||||
and subset.keep_tokens_separator
|
||||
and subset.keep_tokens_separator in caption
|
||||
):
|
||||
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
|
||||
if subset.keep_tokens_separator in flex_part:
|
||||
flex_part, fixed_suffix_part = flex_part.split(subset.keep_tokens_separator, 1)
|
||||
fixed_suffix_tokens = [t.strip() for t in fixed_suffix_part.split(subset.caption_separator) if t.strip()]
|
||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = (
|
||||
math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
|
||||
+ subset.token_warmup_min
|
||||
)
|
||||
flex_tokens = flex_tokens[:tokens_len]
|
||||
|
||||
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
|
||||
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
|
||||
def dropout_tags(tokens):
|
||||
if subset.caption_tag_dropout_rate <= 0:
|
||||
return tokens
|
||||
l = []
|
||||
for token in tokens:
|
||||
if random.random() >= subset.caption_tag_dropout_rate:
|
||||
l.append(token)
|
||||
return l
|
||||
|
||||
if subset.shuffle_caption:
|
||||
random.shuffle(flex_tokens)
|
||||
|
||||
flex_tokens = dropout_tags(flex_tokens)
|
||||
|
||||
tags = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
|
||||
|
||||
if tags is not None:
|
||||
caption = caption + " " + tags
|
||||
|
||||
# process wildcards
|
||||
if subset.enable_wildcard:
|
||||
# wildcard is like '{aaa|bbb|ccc...}'
|
||||
# escape the curly braces like {{ or }}
|
||||
replacer1 = "⦅"
|
||||
replacer2 = "⦆"
|
||||
while replacer1 in caption or replacer2 in caption:
|
||||
replacer1 += "⦅"
|
||||
replacer2 += "⦆"
|
||||
|
||||
caption = caption.replace("{{", replacer1).replace("}}", replacer2)
|
||||
|
||||
# replace the wildcard
|
||||
def replace_wildcard(match):
|
||||
return random.choice(match.group(1).split("|"))
|
||||
|
||||
caption = re.sub(r"\{([^}]+)\}", replace_wildcard, caption)
|
||||
|
||||
# unescape the curly braces
|
||||
caption = caption.replace(replacer1, "{").replace(replacer2, "}")
|
||||
|
||||
# process secondary separator
|
||||
if subset.secondary_separator:
|
||||
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
|
||||
|
||||
# textual inversion対応
|
||||
for str_from, str_to in self.replacements.items():
|
||||
if str_from == "":
|
||||
# replace all
|
||||
if type(str_to) == list:
|
||||
caption = random.choice(str_to)
|
||||
else:
|
||||
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
|
||||
flex_tokens = tokens[:]
|
||||
if subset.keep_tokens > 0:
|
||||
fixed_tokens = flex_tokens[: subset.keep_tokens]
|
||||
flex_tokens = tokens[subset.keep_tokens :]
|
||||
|
||||
if subset.token_warmup_step < 1: # 初回に上書きする
|
||||
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
|
||||
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
|
||||
tokens_len = (
|
||||
math.floor(
|
||||
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
|
||||
)
|
||||
+ subset.token_warmup_min
|
||||
)
|
||||
flex_tokens = flex_tokens[:tokens_len]
|
||||
|
||||
def dropout_tags(tokens):
|
||||
if subset.caption_tag_dropout_rate <= 0:
|
||||
return tokens
|
||||
l = []
|
||||
for token in tokens:
|
||||
if random.random() >= subset.caption_tag_dropout_rate:
|
||||
l.append(token)
|
||||
return l
|
||||
|
||||
if subset.shuffle_caption:
|
||||
random.shuffle(flex_tokens)
|
||||
|
||||
flex_tokens = dropout_tags(flex_tokens)
|
||||
|
||||
caption = ", ".join(fixed_tokens + flex_tokens + fixed_suffix_tokens)
|
||||
|
||||
# process secondary separator
|
||||
if subset.secondary_separator:
|
||||
caption = caption.replace(subset.secondary_separator, subset.caption_separator)
|
||||
|
||||
# textual inversion対応
|
||||
for str_from, str_to in self.replacements.items():
|
||||
if str_from == "":
|
||||
# replace all
|
||||
if type(str_to) == list:
|
||||
caption = random.choice(str_to)
|
||||
else:
|
||||
caption = str_to
|
||||
else:
|
||||
caption = caption.replace(str_from, str_to)
|
||||
caption = str_to
|
||||
else:
|
||||
caption = caption.replace(str_from, str_to)
|
||||
|
||||
return caption
|
||||
|
||||
@@ -1171,24 +1136,28 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
for i, info in enumerate(tqdm(image_infos)):
|
||||
# check disk cache exists and size of text encoder outputs
|
||||
if caching_strategy.cache_to_disk:
|
||||
te_out_npz = caching_strategy.get_outputs_npz_path(info.absolute_path)
|
||||
info.text_encoder_outputs_npz = te_out_npz # set npz filename regardless of cache availability
|
||||
cache_path = caching_strategy.get_cache_path(info.absolute_path)
|
||||
info.text_encoder_outputs_cache_path = cache_path # set npz filename regardless of cache availability
|
||||
|
||||
# if the modulo of num_processes is not equal to process_index, skip caching
|
||||
# this makes each process cache different text encoder outputs
|
||||
if i % num_processes != process_index:
|
||||
continue
|
||||
|
||||
cache_available = caching_strategy.is_disk_cached_outputs_expected(te_out_npz)
|
||||
cache_available = caching_strategy.is_disk_cached_outputs_expected(cache_path)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
batch.append(info)
|
||||
for j, caption in enumerate(info.captions):
|
||||
# do not recommend to use tags when caching text encoder outputs
|
||||
if info.list_of_tags is not None and len(info.list_of_tags) > 0:
|
||||
caption = caption + " " + info.list_of_tags[j % len(info.list_of_tags)]
|
||||
batch.append((info, j, caption))
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
while len(batch) >= batch_size:
|
||||
batches.append(batch[:batch_size])
|
||||
batch = batch[batch_size:]
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
@@ -1413,54 +1382,43 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
flippeds.append(flipped)
|
||||
|
||||
# captionとtext encoder outputを処理する
|
||||
caption = image_info.caption # default
|
||||
|
||||
tokenization_required = (
|
||||
self.text_encoder_output_caching_strategy is None or self.text_encoder_output_caching_strategy.is_partial
|
||||
)
|
||||
text_encoder_outputs = None
|
||||
input_ids = None
|
||||
caption = ""
|
||||
|
||||
if image_info.text_encoder_outputs is not None:
|
||||
# cached
|
||||
# cached on memory
|
||||
text_encoder_outputs = image_info.text_encoder_outputs
|
||||
elif image_info.text_encoder_outputs_npz is not None:
|
||||
if len(text_encoder_outputs) == 1:
|
||||
text_encoder_outputs = text_encoder_outputs[0]
|
||||
else:
|
||||
text_encoder_outputs = random.choices(text_encoder_outputs, weights=image_info.caption_weights)[0]
|
||||
elif image_info.text_encoder_outputs_cache_path is not None:
|
||||
# on disk
|
||||
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_outputs_npz(
|
||||
image_info.text_encoder_outputs_npz
|
||||
index = 0
|
||||
if len(image_info.captions) > 1:
|
||||
index = random.choices(range(len(image_info.captions), weights=image_info.caption_weights))[0]
|
||||
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_from_disk(
|
||||
image_info.text_encoder_outputs_cache_path, index
|
||||
)
|
||||
else:
|
||||
tokenization_required = True
|
||||
text_encoder_outputs_list.append(text_encoder_outputs)
|
||||
|
||||
if tokenization_required:
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
caption = ""
|
||||
tags = None # None if no tags in dataset metadata or Dreambooth method is used
|
||||
if image_info.captions is not None and len(image_info.captions) > 0:
|
||||
# captions_weights may be None
|
||||
caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0]
|
||||
if image_info.list_of_tags is not None and len(image_info.list_of_tags) > 0:
|
||||
tags = random.choices(image_info.list_of_tags, weights=image_info.tags_weights)[0]
|
||||
|
||||
caption = self.process_caption(subset, caption, tags)
|
||||
input_ids = [ids[0] for ids in self.tokenize_strategy.tokenize(caption)] # remove batch dimension
|
||||
# if self.XTI_layers:
|
||||
# caption_layer = []
|
||||
# for layer in self.XTI_layers:
|
||||
# token_strings_from = " ".join(self.token_strings)
|
||||
# token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
# caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
# caption_layer.append(caption_)
|
||||
# captions.append(caption_layer)
|
||||
# else:
|
||||
# captions.append(caption)
|
||||
|
||||
# if not self.token_padding_disabled: # this option might be omitted in future
|
||||
# # TODO get_input_ids must support SD3
|
||||
# if self.XTI_layers:
|
||||
# token_caption = self.get_input_ids(caption_layer, self.tokenizers[0])
|
||||
# else:
|
||||
# token_caption = self.get_input_ids(caption, self.tokenizers[0])
|
||||
# input_ids_list.append(token_caption)
|
||||
|
||||
# if len(self.tokenizers) > 1:
|
||||
# if self.XTI_layers:
|
||||
# token_caption2 = self.get_input_ids(caption_layer, self.tokenizers[1])
|
||||
# else:
|
||||
# token_caption2 = self.get_input_ids(caption, self.tokenizers[1])
|
||||
# input_ids2_list.append(token_caption2)
|
||||
|
||||
input_ids_list.append(input_ids)
|
||||
captions.append(caption)
|
||||
@@ -1798,7 +1756,8 @@ class DreamBoothDataset(BaseDataset):
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
info = ImageInfo(img_path, subset.num_repeats, caption, subset.is_reg, img_path)
|
||||
captions = caption.split("\n") # empty line is allowed
|
||||
info = ImageInfo(img_path, subset.num_repeats, captions, subset.is_reg, img_path)
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
|
||||
@@ -21,6 +21,41 @@ def fire_in_thread(f, *args, **kwargs):
|
||||
threading.Thread(target=f, args=args, kwargs=kwargs).start()
|
||||
|
||||
|
||||
class ImageInfo:
|
||||
def __init__(
|
||||
self, image_key: str, num_repeats: int, captions: Optional[Union[str, list[str]]], is_reg: bool, absolute_path: str
|
||||
) -> None:
|
||||
self.image_key: str = image_key
|
||||
self.num_repeats: int = num_repeats
|
||||
self.captions: Optional[list[str]] = None if captions is None else ([captions] if isinstance(captions, str) else captions)
|
||||
self.caption_weights: Optional[list[float]] = None # weights for each caption in sampling
|
||||
self.list_of_tags: Optional[list[str]] = None
|
||||
self.tags_weights: Optional[list[float]] = None
|
||||
self.is_reg: bool = is_reg
|
||||
self.absolute_path: str = absolute_path
|
||||
self.image_size: Tuple[int, int] = None
|
||||
self.resized_size: Tuple[int, int] = None
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
self.latents: Optional[torch.Tensor] = None
|
||||
self.latents_flipped: Optional[torch.Tensor] = None
|
||||
self.latents_cache_path: Optional[str] = None # set in cache_latents
|
||||
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||
# crop left top right bottom in original pixel size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
|
||||
self.cond_img_path: Optional[str] = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image. None if not the latents is cached
|
||||
self.text_encoder_outputs_cache_path: Optional[str] = None # set in cache_text_encoder_outputs
|
||||
|
||||
# new
|
||||
self.text_encoder_outputs: Optional[list[list[torch.Tensor]]] = None
|
||||
# old
|
||||
self.text_encoder_outputs1: Optional[torch.Tensor] = None
|
||||
self.text_encoder_outputs2: Optional[torch.Tensor] = None
|
||||
self.text_encoder_pool2: Optional[torch.Tensor] = None
|
||||
|
||||
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
|
||||
|
||||
|
||||
# region Logging
|
||||
|
||||
|
||||
|
||||
12
sd3_train.py
12
sd3_train.py
@@ -75,6 +75,12 @@ def train(args):
|
||||
)
|
||||
args.cache_text_encoder_outputs = True
|
||||
|
||||
if args.cache_text_encoder_outputs:
|
||||
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
|
||||
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
|
||||
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
|
||||
)
|
||||
|
||||
assert not args.train_text_encoder or (args.use_t5xxl_cache_only or not args.cache_text_encoder_outputs), (
|
||||
"when training text encoder, text encoder outputs must not be cached (except for T5XXL)"
|
||||
+ " / text encoderの学習時はtext encoderの出力はキャッシュできません(t5xxlのみキャッシュすることは可能です)"
|
||||
@@ -168,8 +174,8 @@ def train(args):
|
||||
args.text_encoder_batch_size,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
args.t5xxl_max_token_length,
|
||||
args.apply_lg_attn_mask,
|
||||
)
|
||||
)
|
||||
train_dataset_group.set_current_strategies()
|
||||
@@ -278,8 +284,8 @@ def train(args):
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
train_clip or args.use_t5xxl_cache_only, # if clip is trained or t5xxl is cached, caching is partial
|
||||
args.t5xxl_max_token_length,
|
||||
args.apply_lg_attn_mask,
|
||||
args.apply_t5_attn_mask,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_caching_strategy)
|
||||
|
||||
|
||||
@@ -43,6 +43,10 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
assert (
|
||||
train_dataset_group.is_text_encoder_output_cacheable()
|
||||
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"
|
||||
assert args.apply_lg_attn_mask == args.apply_t5_attn_mask, (
|
||||
"apply_lg_attn_mask and apply_t5_attn_mask must be the same when caching text encoder outputs"
|
||||
" / text encoderの出力をキャッシュするときにはapply_lg_attn_maskとapply_t5_attn_maskは同じである必要があります"
|
||||
)
|
||||
|
||||
# prepare CLIP-L/CLIP-G/T5XXL training flags
|
||||
self.train_clip = not args.network_train_unet_only
|
||||
@@ -183,8 +187,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
is_partial=self.train_clip or self.train_t5xxl,
|
||||
max_token_length=args.t5xxl_max_token_length,
|
||||
apply_lg_attn_mask=args.apply_lg_attn_mask,
|
||||
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -321,7 +321,11 @@ def train(args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, False, is_weighted=args.weighted_captions
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_token_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||
|
||||
|
||||
@@ -223,7 +223,11 @@ def train(args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, False
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_token_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||
|
||||
|
||||
@@ -195,7 +195,11 @@ def train(args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
# Text Encodes are eval and no grad
|
||||
text_encoder_output_caching_strategy = strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, False
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_token_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(text_encoder_output_caching_strategy)
|
||||
|
||||
|
||||
@@ -81,7 +81,11 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
return strategy_sdxl.SdxlTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk, None, args.skip_cache_check, is_weighted=args.weighted_captions
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
None,
|
||||
args.skip_cache_check,
|
||||
args.max_tolen_length,
|
||||
is_weighted=args.weighted_captions,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user