mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
feat: refactor latent cache format
This commit is contained in:
@@ -177,7 +177,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -180,7 +180,7 @@ def main(args):
|
||||
|
||||
# バッチへ追加
|
||||
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
|
||||
image_info.latents_npz = npz_file_name
|
||||
image_info.latents_cache_path = npz_file_name
|
||||
image_info.bucket_reso = reso
|
||||
image_info.resized_size = resized_size
|
||||
image_info.image = image
|
||||
|
||||
@@ -198,7 +198,7 @@ def train(args):
|
||||
ae.requires_grad_(False)
|
||||
ae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(ae, accelerator)
|
||||
train_dataset_group.new_cache_latents(ae, accelerator, args.force_cache_precision)
|
||||
|
||||
ae.to("cpu") # if no sampling, vae can be deleted
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from safetensors.torch import safe_open, save_file
|
||||
import torch
|
||||
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
@@ -12,6 +13,7 @@ from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjecti
|
||||
# TODO remove circular import by moving ImageInfo to a separate file
|
||||
# from library.train_util import ImageInfo
|
||||
|
||||
from library import utils
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -20,6 +22,27 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
|
||||
if dtype is None:
|
||||
# all dtypes are acceptable
|
||||
return get_available_dtypes()
|
||||
|
||||
dtype = utils.str_to_dtype(dtype) if isinstance(dtype, str) else dtype
|
||||
compatible_dtypes = [torch.float32]
|
||||
if dtype.itemsize == 1: # fp8
|
||||
compatible_dtypes.append(torch.bfloat16)
|
||||
compatible_dtypes.append(torch.float16)
|
||||
compatible_dtypes.append(dtype) # add the specified: bf16, fp16, one of fp8
|
||||
return compatible_dtypes
|
||||
|
||||
|
||||
def get_available_dtypes() -> List[torch.dtype]:
|
||||
"""
|
||||
Returns the list of available dtypes for latents caching. Higher precision is preferred.
|
||||
"""
|
||||
return [torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn, torch.float8_e5m2]
|
||||
|
||||
|
||||
class TokenizeStrategy:
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
@@ -382,11 +405,18 @@ class LatentsCachingStrategy:
|
||||
|
||||
_strategy = None # strategy instance: actual strategy class
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
def __init__(
|
||||
self, architecture: str, latents_stride: int, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
|
||||
) -> None:
|
||||
self._architecture = architecture
|
||||
self._latents_stride = latents_stride
|
||||
self._cache_to_disk = cache_to_disk
|
||||
self._batch_size = batch_size
|
||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||
|
||||
self.load_version_warning_printed = False
|
||||
self.save_version_warning_printed = False
|
||||
|
||||
@classmethod
|
||||
def set_strategy(cls, strategy):
|
||||
if cls._strategy is not None:
|
||||
@@ -397,6 +427,14 @@ class LatentsCachingStrategy:
|
||||
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
||||
return cls._strategy
|
||||
|
||||
@property
|
||||
def architecture(self):
|
||||
return self._architecture
|
||||
|
||||
@property
|
||||
def latents_stride(self):
|
||||
return self._latents_stride
|
||||
|
||||
@property
|
||||
def cache_to_disk(self):
|
||||
return self._cache_to_disk
|
||||
@@ -407,69 +445,143 @@ class LatentsCachingStrategy:
|
||||
|
||||
@property
|
||||
def cache_suffix(self):
|
||||
raise NotImplementedError
|
||||
return f"_{self.architecture.lower()}.safetensors"
|
||||
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str, npz_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
w, h = os.path.splitext(npz_path)[0].split("_")[-2].split("x")
|
||||
def get_image_size_from_disk_cache_path(self, absolute_path: str, cache_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
|
||||
return int(w), int(h)
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
raise NotImplementedError
|
||||
def get_latents_cache_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
|
||||
|
||||
def is_disk_cached_latents_expected(
|
||||
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
preferred_dtype: Optional[Union[str, torch.dtype]],
|
||||
) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def cache_batch_latents(self, model: Any, batch: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_key_suffix(
|
||||
self,
|
||||
bucket_reso: Optional[Tuple[int, int]] = None,
|
||||
latents_size: Optional[Tuple[int, int]] = None,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
if dtype is None, it returns "_32x64" for example.
|
||||
"""
|
||||
if latents_size is not None:
|
||||
expected_latents_size = latents_size # H, W
|
||||
else:
|
||||
# bucket_reso is (W, H)
|
||||
expected_latents_size = (bucket_reso[1] // self.latents_stride, bucket_reso[0] // self.latents_stride) # H, W
|
||||
|
||||
if dtype is None:
|
||||
dtype_suffix = ""
|
||||
else:
|
||||
dtype_suffix = "_" + utils.dtype_to_normalized_str(dtype)
|
||||
|
||||
# e.g. "_32x64_float16", HxW, dtype
|
||||
key_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}{dtype_suffix}"
|
||||
|
||||
return key_suffix
|
||||
|
||||
def get_compatible_latents_keys(
|
||||
self,
|
||||
keys: set[str],
|
||||
dtype: Union[str, torch.dtype],
|
||||
flip_aug: bool,
|
||||
bucket_reso: Optional[Tuple[int, int]] = None,
|
||||
latents_size: Optional[Tuple[int, int]] = None,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
bucket_reso is (W, H), latents_size is (H, W)
|
||||
"""
|
||||
|
||||
latents_key = None
|
||||
flipped_latents_key = None
|
||||
|
||||
compatible_dtypes = get_compatible_dtypes(dtype)
|
||||
|
||||
for compat_dtype in compatible_dtypes:
|
||||
key_suffix = self.get_key_suffix(bucket_reso, latents_size, compat_dtype)
|
||||
|
||||
if latents_key is None:
|
||||
latents_key = "latents" + key_suffix
|
||||
if latents_key not in keys:
|
||||
latents_key = None
|
||||
if flip_aug and flipped_latents_key is None:
|
||||
flipped_latents_key = "latents_flipped" + key_suffix
|
||||
if flipped_latents_key not in keys:
|
||||
flipped_latents_key = None
|
||||
|
||||
if latents_key is not None and (flipped_latents_key is not None or not flip_aug):
|
||||
break
|
||||
|
||||
return latents_key, flipped_latents_key
|
||||
|
||||
def _default_is_disk_cached_latents_expected(
|
||||
self,
|
||||
latents_stride: int,
|
||||
bucket_reso: Tuple[int, int],
|
||||
npz_path: str,
|
||||
latents_cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
multi_resolution: bool = False,
|
||||
preferred_dtype: Optional[Union[str, torch.dtype]],
|
||||
):
|
||||
# multi_resolution is always enabled for any strategy
|
||||
if not self.cache_to_disk:
|
||||
return False
|
||||
if not os.path.exists(npz_path):
|
||||
if not os.path.exists(latents_cache_path):
|
||||
return False
|
||||
if self.skip_disk_cache_validity_check:
|
||||
return True
|
||||
|
||||
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
|
||||
# e.g. "_32x64", HxW
|
||||
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
|
||||
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
return False
|
||||
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
||||
return False
|
||||
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||
# safe_open locks the file, so we cannot use it for checking keys
|
||||
# with safe_open(latents_cache_path, framework="pt") as f:
|
||||
# keys = f.keys()
|
||||
with utils.MemoryEfficientSafeOpen(latents_cache_path) as f:
|
||||
keys = f.keys()
|
||||
|
||||
if alpha_mask and "alpha_mask" + key_suffix_without_dtype not in keys:
|
||||
# print(f"alpha_mask not found: {latents_cache_path}")
|
||||
return False
|
||||
|
||||
if preferred_dtype is None:
|
||||
# remove dtype suffix from keys, because any dtype is acceptable
|
||||
keys = [key.rsplit("_", 1)[0] for key in keys if not key.endswith(key_suffix_without_dtype)]
|
||||
keys = set(keys)
|
||||
if "latents" + key_suffix_without_dtype not in keys:
|
||||
# print(f"No preferred: latents {key_suffix_without_dtype} not found: {latents_cache_path}")
|
||||
return False
|
||||
if flip_aug and "latents_flipped" + key_suffix_without_dtype not in keys:
|
||||
# print(f"No preferred: latents_flipped {key_suffix_without_dtype} not found: {latents_cache_path}")
|
||||
return False
|
||||
else:
|
||||
# specific dtype or compatible dtype is required
|
||||
latents_key, flipped_latents_key = self.get_compatible_latents_keys(
|
||||
keys, preferred_dtype, flip_aug, bucket_reso=bucket_reso
|
||||
)
|
||||
if latents_key is None or (flip_aug and flipped_latents_key is None):
|
||||
# print(f"Precise dtype not found: {latents_cache_path}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
logger.error(f"Error loading file: {latents_cache_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def _default_cache_batch_latents(
|
||||
self,
|
||||
encode_by_vae,
|
||||
vae_device,
|
||||
vae_dtype,
|
||||
image_infos: List,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
random_crop: bool,
|
||||
multi_resolution: bool = False,
|
||||
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
|
||||
):
|
||||
"""
|
||||
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
||||
@@ -499,13 +611,8 @@ class LatentsCachingStrategy:
|
||||
original_size = original_sizes[i]
|
||||
crop_ltrb = crop_ltrbs[i]
|
||||
|
||||
latents_size = latents.shape[1:3] # H, W
|
||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
|
||||
|
||||
if self.cache_to_disk:
|
||||
self.save_latents_to_disk(
|
||||
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
|
||||
)
|
||||
self.save_latents_to_disk(info.latents_cache_path, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
|
||||
else:
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
@@ -515,56 +622,111 @@ class LatentsCachingStrategy:
|
||||
info.alpha_mask = alpha_mask
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""
|
||||
for SD/SDXL
|
||||
"""
|
||||
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _default_load_latents_from_disk(
|
||||
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
if latents_stride is None:
|
||||
key_reso_suffix = ""
|
||||
else:
|
||||
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
with safe_open(cache_path, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
version = metadata.get("format_version", "0.0.0")
|
||||
major, minor, patch = map(int, version.split("."))
|
||||
if major > 1: # or (major == 1 and minor > 0):
|
||||
if not self.load_version_warning_printed:
|
||||
self.load_version_warning_printed = True
|
||||
logger.warning(
|
||||
f"Existing latents cache file has a higher version {version} for {cache_path}. This may cause issues."
|
||||
)
|
||||
|
||||
npz = np.load(npz_path)
|
||||
if "latents" + key_reso_suffix not in npz:
|
||||
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
|
||||
keys = f.keys()
|
||||
|
||||
latents_key, flipped_latents_key = self.get_compatible_latents_keys(keys, None, flip_aug=True, bucket_reso=bucket_reso)
|
||||
|
||||
key_suffix_without_dtype = self.get_key_suffix(bucket_reso=bucket_reso, dtype=None)
|
||||
alpha_mask_key = "alpha_mask" + key_suffix_without_dtype
|
||||
|
||||
latents = f.get_tensor(latents_key)
|
||||
flipped_latents = f.get_tensor(flipped_latents_key) if flipped_latents_key is not None else None
|
||||
alpha_mask = f.get_tensor(alpha_mask_key) if alpha_mask_key in keys else None
|
||||
|
||||
original_size = [int(metadata["width"]), int(metadata["height"])]
|
||||
crop_ltrb = metadata[f"crop_ltrb" + key_suffix_without_dtype]
|
||||
crop_ltrb = list(map(int, crop_ltrb.split(",")))
|
||||
|
||||
latents = npz["latents" + key_reso_suffix]
|
||||
original_size = npz["original_size" + key_reso_suffix].tolist()
|
||||
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
|
||||
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
|
||||
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
|
||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
def save_latents_to_disk(
|
||||
self,
|
||||
npz_path,
|
||||
latents_tensor,
|
||||
original_size,
|
||||
crop_ltrb,
|
||||
flipped_latents_tensor=None,
|
||||
alpha_mask=None,
|
||||
key_reso_suffix="",
|
||||
cache_path: str,
|
||||
latents_tensor: torch.Tensor,
|
||||
original_size: Tuple[int, int],
|
||||
crop_ltrb: List[int],
|
||||
flipped_latents_tensor: Optional[torch.Tensor] = None,
|
||||
alpha_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
kwargs = {}
|
||||
dtype = latents_tensor.dtype
|
||||
latents_size = latents_tensor.shape[1:3] # H, W
|
||||
tensor_dict = {}
|
||||
|
||||
if os.path.exists(npz_path):
|
||||
# load existing npz and update it
|
||||
npz = np.load(npz_path)
|
||||
for key in npz.files:
|
||||
kwargs[key] = npz[key]
|
||||
overwrite = False
|
||||
if os.path.exists(cache_path):
|
||||
# load existing safetensors and update it
|
||||
overwrite = True
|
||||
|
||||
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
|
||||
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
|
||||
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
|
||||
# we cannot use safe_open here because it locks the file
|
||||
# with safe_open(cache_path, framework="pt") as f:
|
||||
with utils.MemoryEfficientSafeOpen(cache_path) as f:
|
||||
metadata = f.metadata()
|
||||
keys = f.keys()
|
||||
for key in keys:
|
||||
tensor_dict[key] = f.get_tensor(key)
|
||||
assert metadata["architecture"] == self.architecture
|
||||
|
||||
file_version = metadata.get("format_version", "0.0.0")
|
||||
major, minor, patch = map(int, file_version.split("."))
|
||||
if major > 1 or (major == 1 and minor > 0):
|
||||
self.save_version_warning_printed = True
|
||||
logger.warning(
|
||||
f"Existing latents cache file has a higher version {file_version} for {cache_path}. This may cause issues."
|
||||
)
|
||||
else:
|
||||
metadata = {}
|
||||
metadata["architecture"] = self.architecture
|
||||
metadata["width"] = f"{original_size[0]}"
|
||||
metadata["height"] = f"{original_size[1]}"
|
||||
metadata["format_version"] = "1.0.0"
|
||||
|
||||
metadata[f"crop_ltrb_{latents_size[0]}x{latents_size[1]}"] = ",".join(map(str, crop_ltrb))
|
||||
|
||||
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
|
||||
if latents_tensor is not None:
|
||||
tensor_dict["latents" + key_suffix] = latents_tensor
|
||||
if flipped_latents_tensor is not None:
|
||||
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
|
||||
tensor_dict["latents_flipped" + key_suffix] = flipped_latents_tensor
|
||||
if alpha_mask is not None:
|
||||
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
|
||||
np.savez(npz_path, **kwargs)
|
||||
key_suffix_without_dtype = self.get_key_suffix(latents_size=latents_size, dtype=None)
|
||||
tensor_dict["alpha_mask" + key_suffix_without_dtype] = alpha_mask
|
||||
|
||||
# remove lower precision latents if higher precision latents are already cached
|
||||
if overwrite:
|
||||
available_dtypes = get_available_dtypes()
|
||||
available_itemsize = None
|
||||
available_itemsize_flipped = None
|
||||
for dtype in available_dtypes:
|
||||
key_suffix = self.get_key_suffix(latents_size=latents_size, dtype=dtype)
|
||||
if "latents" + key_suffix in tensor_dict:
|
||||
if available_itemsize is None:
|
||||
available_itemsize = dtype.itemsize
|
||||
elif available_itemsize > dtype.itemsize:
|
||||
# if higher precision latents are already cached, remove lower precision latents
|
||||
del tensor_dict["latents" + key_suffix]
|
||||
|
||||
if "latents_flipped" + key_suffix in tensor_dict:
|
||||
if available_itemsize_flipped is None:
|
||||
available_itemsize_flipped = dtype.itemsize
|
||||
elif available_itemsize_flipped > dtype.itemsize:
|
||||
del tensor_dict["latents_flipped" + key_suffix]
|
||||
|
||||
save_file(tensor_dict, cache_path, metadata=metadata)
|
||||
|
||||
@@ -195,29 +195,25 @@ class FluxTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
|
||||
class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
FLUX_LATENTS_NPZ_SUFFIX = "_flux.npz"
|
||||
ARCHITECTURE = "flux"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
super().__init__(FluxLatentsCachingStrategy.ARCHITECTURE, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ FluxLatentsCachingStrategy.FLUX_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
def is_disk_cached_latents_expected(
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
preferred_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
@@ -225,9 +221,7 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||
)
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
@@ -134,30 +134,28 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
||||
# and we keep the old npz for the backward compatibility.
|
||||
|
||||
SD_OLD_LATENTS_NPZ_SUFFIX = ".npz"
|
||||
SD_LATENTS_NPZ_SUFFIX = "_sd.npz"
|
||||
SDXL_LATENTS_NPZ_SUFFIX = "_sdxl.npz"
|
||||
ARCHITECTURE_SD = "sd"
|
||||
ARCHITECTURE_SDXL = "sdxl"
|
||||
|
||||
def __init__(self, sd: bool, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
arch = SdSdxlLatentsCachingStrategy.ARCHITECTURE_SD if sd else SdSdxlLatentsCachingStrategy.ARCHITECTURE_SDXL
|
||||
super().__init__(arch, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
self.sd = sd
|
||||
self.suffix = (
|
||||
SdSdxlLatentsCachingStrategy.SD_LATENTS_NPZ_SUFFIX if sd else SdSdxlLatentsCachingStrategy.SDXL_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return self.suffix
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
# support old .npz
|
||||
old_npz_file = os.path.splitext(absolute_path)[0] + SdSdxlLatentsCachingStrategy.SD_OLD_LATENTS_NPZ_SUFFIX
|
||||
if os.path.exists(old_npz_file):
|
||||
return old_npz_file
|
||||
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix
|
||||
def is_disk_cached_latents_expected(
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
preferred_dtype: Optional[torch.dtype] = None,
|
||||
) -> bool:
|
||||
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
||||
def load_latents_from_disk(
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
|
||||
@@ -382,29 +382,25 @@ class Sd3TextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
|
||||
|
||||
|
||||
class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
||||
ARCHITECTURE_SD3 = "sd3"
|
||||
|
||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
super().__init__(Sd3LatentsCachingStrategy.ARCHITECTURE_SD3, 8, cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||
|
||||
@property
|
||||
def cache_suffix(self) -> str:
|
||||
return Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
|
||||
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||
return (
|
||||
os.path.splitext(absolute_path)[0]
|
||||
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True)
|
||||
def is_disk_cached_latents_expected(
|
||||
self,
|
||||
bucket_reso: Tuple[int, int],
|
||||
cache_path: str,
|
||||
flip_aug: bool,
|
||||
alpha_mask: bool,
|
||||
preferred_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
return self._default_is_disk_cached_latents_expected(bucket_reso, cache_path, flip_aug, alpha_mask, preferred_dtype)
|
||||
|
||||
def load_latents_from_disk(
|
||||
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
||||
self, cache_path: str, bucket_reso: Tuple[int, int]
|
||||
) -> Tuple[torch.Tensor, List[int], List[int], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
return self._default_load_latents_from_disk(cache_path, bucket_reso)
|
||||
|
||||
# TODO remove circular dependency for ImageInfo
|
||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||
@@ -412,9 +408,7 @@ class Sd3LatentsCachingStrategy(LatentsCachingStrategy):
|
||||
vae_device = vae.device
|
||||
vae_dtype = vae.dtype
|
||||
|
||||
self._default_cache_batch_latents(
|
||||
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True
|
||||
)
|
||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
||||
|
||||
if not train_util.HIGH_VRAM:
|
||||
train_util.clean_memory_on_device(vae.device)
|
||||
|
||||
@@ -158,11 +158,10 @@ class ImageInfo:
|
||||
self.bucket_reso: Tuple[int, int] = None
|
||||
self.latents: Optional[torch.Tensor] = None
|
||||
self.latents_flipped: Optional[torch.Tensor] = None
|
||||
self.latents_npz: Optional[str] = None # set in cache_latents
|
||||
self.latents_cache_path: Optional[str] = None # set in cache_latents
|
||||
self.latents_original_size: Optional[Tuple[int, int]] = None # original image size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = (
|
||||
None # crop left top right bottom in original pixel size, not latents size
|
||||
)
|
||||
# crop left top right bottom in original pixel size, not latents size
|
||||
self.latents_crop_ltrb: Optional[Tuple[int, int]] = None
|
||||
self.cond_img_path: Optional[str] = None
|
||||
self.image: Optional[Image.Image] = None # optional, original PIL Image
|
||||
self.text_encoder_outputs_npz: Optional[str] = None # set in cache_text_encoder_outputs
|
||||
@@ -323,6 +322,9 @@ class BucketManager:
|
||||
else:
|
||||
resized_width = bucket_reso[0]
|
||||
resized_height = bucket_reso[0] / image_ar
|
||||
resized_width = int(resized_width + 0.5)
|
||||
resized_height = int(resized_height + 0.5)
|
||||
|
||||
crop_left = (bucket_reso[0] - resized_width) // 2
|
||||
crop_top = (bucket_reso[1] - resized_height) // 2
|
||||
crop_right = crop_left + resized_width
|
||||
@@ -1040,7 +1042,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
]
|
||||
)
|
||||
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator):
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool = False):
|
||||
r"""
|
||||
a brand new method to cache latents. This method caches latents with caching strategy.
|
||||
normal cache_latents method is used by default, but this method is used when caching strategy is specified.
|
||||
@@ -1094,17 +1096,18 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
try:
|
||||
# iterate images
|
||||
logger.info("caching latents...")
|
||||
logger.info(f"Caching latents for dataset with {len(image_infos)} images.")
|
||||
preferred_dtype = model.dtype if force_cache_precision else None
|
||||
for i, info in enumerate(tqdm(image_infos)):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
if info.latents_cache_path is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
|
||||
info.latents_cache_path = caching_strategy.get_latents_cache_path(info.absolute_path, info.image_size)
|
||||
|
||||
# if the modulo of num_processes is not equal to process_index, skip caching
|
||||
# this makes each process cache different latents
|
||||
@@ -1114,7 +1117,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# print(f"{process_index}/{num_processes} {i}/{len(image_infos)} {info.latents_npz}")
|
||||
|
||||
cache_available = caching_strategy.is_disk_cached_latents_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
info.bucket_reso, info.latents_cache_path, subset.flip_aug, subset.alpha_mask, preferred_dtype
|
||||
)
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
@@ -1144,81 +1147,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
finally:
|
||||
executor.shutdown()
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
||||
# マルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching latents.")
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
|
||||
# split by resolution and some conditions
|
||||
class Condition:
|
||||
def __init__(self, reso, flip_aug, alpha_mask, random_crop):
|
||||
self.reso = reso
|
||||
self.flip_aug = flip_aug
|
||||
self.alpha_mask = alpha_mask
|
||||
self.random_crop = random_crop
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.reso == other.reso
|
||||
and self.flip_aug == other.flip_aug
|
||||
and self.alpha_mask == other.alpha_mask
|
||||
and self.random_crop == other.random_crop
|
||||
)
|
||||
|
||||
batches: List[Tuple[Condition, List[ImageInfo]]] = []
|
||||
batch: List[ImageInfo] = []
|
||||
current_condition = None
|
||||
|
||||
logger.info("checking cache validity...")
|
||||
for info in tqdm(image_infos):
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None: # fine tuning dataset
|
||||
continue
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if cache_to_disk:
|
||||
info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
cache_available = is_disk_cached_latents_is_expected(
|
||||
info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask
|
||||
)
|
||||
|
||||
if cache_available: # do not add to batch
|
||||
continue
|
||||
|
||||
# if batch is not empty and condition is changed, flush the batch. Note that current_condition is not None if batch is not empty
|
||||
condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
|
||||
batch.append(info)
|
||||
current_condition = condition
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= vae_batch_size:
|
||||
batches.append((current_condition, batch))
|
||||
batch = []
|
||||
current_condition = None
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append((current_condition, batch))
|
||||
|
||||
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
|
||||
return
|
||||
|
||||
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
|
||||
logger.info("caching latents...")
|
||||
for condition, batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
|
||||
r"""
|
||||
a brand new method to cache text encoder outputs. This method caches text encoder outputs with caching strategy.
|
||||
@@ -1275,131 +1203,6 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop)
|
||||
caching_strategy.cache_batch_outputs(tokenize_strategy, models, text_encoding_strategy, batch)
|
||||
|
||||
# if weight_dtype is specified, Text Encoder itself and output will be converted to the dtype
|
||||
# this method is only for SDXL, but it should be implemented here because it needs to be a method of dataset
|
||||
# to support SD1/2, it needs a flag for v2, but it is postponed
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, output_dtype, cache_to_disk=False, is_main_process=True
|
||||
):
|
||||
assert len(tokenizers) == 2, "only support SDXL"
|
||||
return self.cache_text_encoder_outputs_common(
|
||||
tokenizers, text_encoders, [device, device], output_dtype, [output_dtype], cache_to_disk, is_main_process
|
||||
)
|
||||
|
||||
# same as above, but for SD3
|
||||
def cache_text_encoder_outputs_sd3(
|
||||
self, tokenizer, text_encoders, devices, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
|
||||
):
|
||||
return self.cache_text_encoder_outputs_common(
|
||||
[tokenizer],
|
||||
text_encoders,
|
||||
devices,
|
||||
output_dtype,
|
||||
te_dtypes,
|
||||
cache_to_disk,
|
||||
is_main_process,
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
def cache_text_encoder_outputs_common(
|
||||
self,
|
||||
tokenizers,
|
||||
text_encoders,
|
||||
devices,
|
||||
output_dtype,
|
||||
te_dtypes,
|
||||
cache_to_disk=False,
|
||||
is_main_process=True,
|
||||
file_suffix=TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX,
|
||||
batch_size=None,
|
||||
):
|
||||
# latentsのキャッシュと同様に、ディスクへのキャッシュに対応する
|
||||
# またマルチGPUには対応していないので、そちらはtools/cache_latents.pyを使うこと
|
||||
logger.info("caching text encoder outputs.")
|
||||
|
||||
tokenize_strategy = TokenizeStrategy.get_strategy()
|
||||
|
||||
if batch_size is None:
|
||||
batch_size = self.batch_size
|
||||
|
||||
image_infos = list(self.image_data.values())
|
||||
|
||||
logger.info("checking cache existence...")
|
||||
image_infos_to_cache = []
|
||||
for info in tqdm(image_infos):
|
||||
# subset = self.image_to_subset[info.image_key]
|
||||
if cache_to_disk:
|
||||
te_out_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.text_encoder_outputs_npz = te_out_npz
|
||||
|
||||
if not is_main_process: # store to info only
|
||||
continue
|
||||
|
||||
if os.path.exists(te_out_npz):
|
||||
# TODO check varidity of cache here
|
||||
continue
|
||||
|
||||
image_infos_to_cache.append(info)
|
||||
|
||||
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
|
||||
return
|
||||
|
||||
# prepare tokenizers and text encoders
|
||||
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
|
||||
text_encoder.to(device)
|
||||
if te_dtype is not None:
|
||||
text_encoder.to(dtype=te_dtype)
|
||||
|
||||
# create batch
|
||||
is_sd3 = len(tokenizers) == 1
|
||||
batch = []
|
||||
batches = []
|
||||
for info in image_infos_to_cache:
|
||||
if not is_sd3:
|
||||
input_ids1 = self.get_input_ids(info.caption, tokenizers[0])
|
||||
input_ids2 = self.get_input_ids(info.caption, tokenizers[1])
|
||||
batch.append((info, input_ids1, input_ids2))
|
||||
else:
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(info.caption)
|
||||
batch.append((info, l_tokens, g_tokens, t5_tokens))
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
# iterate batches: call text encoder and cache outputs for memory or disk
|
||||
logger.info("caching text encoder outputs...")
|
||||
if not is_sd3:
|
||||
for batch in tqdm(batches):
|
||||
infos, input_ids1, input_ids2 = zip(*batch)
|
||||
input_ids1 = torch.stack(input_ids1, dim=0)
|
||||
input_ids2 = torch.stack(input_ids2, dim=0)
|
||||
cache_batch_text_encoder_outputs(
|
||||
infos, tokenizers, text_encoders, self.max_token_length, cache_to_disk, input_ids1, input_ids2, output_dtype
|
||||
)
|
||||
else:
|
||||
for batch in tqdm(batches):
|
||||
infos, l_tokens, g_tokens, t5_tokens = zip(*batch)
|
||||
|
||||
# stack tokens
|
||||
# l_tokens = [tokens[0] for tokens in l_tokens]
|
||||
# g_tokens = [tokens[0] for tokens in g_tokens]
|
||||
# t5_tokens = [tokens[0] for tokens in t5_tokens]
|
||||
|
||||
cache_batch_text_encoder_outputs_sd3(
|
||||
infos,
|
||||
tokenizers[0],
|
||||
text_encoders,
|
||||
self.max_token_length,
|
||||
cache_to_disk,
|
||||
(l_tokens, g_tokens, t5_tokens),
|
||||
output_dtype,
|
||||
)
|
||||
|
||||
def get_image_size(self, image_path):
|
||||
# return imagesize.get(image_path)
|
||||
image_size = imagesize.get(image_path)
|
||||
@@ -1522,17 +1325,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1])
|
||||
|
||||
image = None
|
||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
elif image_info.latents_cache_path is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
|
||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_cache_path, image_info.bucket_reso)
|
||||
)
|
||||
if flipped:
|
||||
latents = flipped_latents
|
||||
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem
|
||||
alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1]
|
||||
del flipped_latents
|
||||
latents = torch.FloatTensor(latents)
|
||||
if alpha_mask is not None:
|
||||
alpha_mask = torch.FloatTensor(alpha_mask)
|
||||
|
||||
image = None
|
||||
else:
|
||||
@@ -1885,28 +1685,28 @@ class DreamBoothDataset(BaseDataset):
|
||||
if strategy is not None:
|
||||
logger.info("get image size from name of cache files")
|
||||
|
||||
# make image path to npz path mapping
|
||||
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
|
||||
npz_paths.sort(
|
||||
# make image path to cache path mapping
|
||||
cache_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
|
||||
cache_paths.sort(
|
||||
key=lambda item: item.rsplit("_", maxsplit=2)[0]
|
||||
) # sort by name excluding resolution and cache_suffix
|
||||
npz_path_index = 0
|
||||
cache_path_index = 0
|
||||
|
||||
size_set_count = 0
|
||||
for i, img_path in enumerate(tqdm(img_paths)):
|
||||
l = len(os.path.splitext(img_path)[0]) # remove extension
|
||||
found = False
|
||||
while npz_path_index < len(npz_paths): # until found or end of npz_paths
|
||||
while cache_path_index < len(cache_paths): # until found or end of npz_paths
|
||||
# npz_paths are sorted, so if npz_path > img_path, img_path is not found
|
||||
if npz_paths[npz_path_index][:l] > img_path[:l]:
|
||||
if cache_paths[cache_path_index][:l] > img_path[:l]:
|
||||
break
|
||||
if npz_paths[npz_path_index][:l] == img_path[:l]: # found
|
||||
if cache_paths[cache_path_index][:l] == img_path[:l]: # found
|
||||
found = True
|
||||
break
|
||||
npz_path_index += 1 # next npz_path
|
||||
cache_path_index += 1 # next npz_path
|
||||
|
||||
if found:
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(img_path, npz_paths[npz_path_index])
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(img_path, cache_paths[cache_path_index])
|
||||
else:
|
||||
w, h = None, None
|
||||
|
||||
@@ -2139,8 +1939,8 @@ class FineTuningDataset(BaseDataset):
|
||||
image_info.image_size = img_md.get("train_resolution")
|
||||
|
||||
if not subset.color_aug and not subset.random_crop:
|
||||
# if npz exists, use them
|
||||
image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
# if cache exists, use them
|
||||
image_info.latents_cache_path, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
@@ -2161,7 +1961,7 @@ class FineTuningDataset(BaseDataset):
|
||||
for image_info in self.image_data.values():
|
||||
subset = self.image_to_subset[image_info.image_key]
|
||||
|
||||
has_npz = image_info.latents_npz is not None
|
||||
has_npz = image_info.latents_cache_path is not None
|
||||
npz_any = npz_any or has_npz
|
||||
|
||||
if subset.flip_aug:
|
||||
@@ -2233,7 +2033,7 @@ class FineTuningDataset(BaseDataset):
|
||||
# npz情報をきれいにしておく
|
||||
if not use_npz_latents:
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_npz = image_info.latents_npz_flipped = None
|
||||
image_info.latents_cache_path = image_info.latents_npz_flipped = None
|
||||
|
||||
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
@@ -2382,11 +2182,8 @@ class ControlNetDataset(BaseDataset):
|
||||
self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager
|
||||
self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
||||
return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
||||
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator):
|
||||
return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator)
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool):
|
||||
return self.dreambooth_dataset_delegate.new_cache_latents(model, accelerator, force_cache_precision)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], is_main_process: bool):
|
||||
return self.dreambooth_dataset_delegate.new_cache_text_encoder_outputs(models, is_main_process)
|
||||
@@ -2485,33 +2282,12 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
for dataset in self.datasets:
|
||||
dataset.enable_XTI(*args, **kwargs)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, file_suffix=".npz"):
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator, force_cache_precision: bool):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, file_suffix)
|
||||
|
||||
def new_cache_latents(self, model: Any, accelerator: Accelerator):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.new_cache_latents(model, accelerator)
|
||||
dataset.new_cache_latents(model, accelerator, force_cache_precision)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
def cache_text_encoder_outputs(
|
||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
||||
):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
|
||||
|
||||
def cache_text_encoder_outputs_sd3(
|
||||
self, tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk=False, is_main_process=True, batch_size=None
|
||||
):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
dataset.cache_text_encoder_outputs_sd3(
|
||||
tokenizer, text_encoders, device, output_dtype, te_dtypes, cache_to_disk, is_main_process, batch_size
|
||||
)
|
||||
|
||||
def new_cache_text_encoder_outputs(self, models: List[Any], accelerator: Accelerator):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
logger.info(f"[Dataset {i}]")
|
||||
@@ -2556,72 +2332,6 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
dataset.disable_token_padding()
|
||||
|
||||
|
||||
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
|
||||
if not os.path.exists(npz_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver?
|
||||
return False
|
||||
if npz["latents"].shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
|
||||
if flip_aug:
|
||||
if "latents_flipped" not in npz:
|
||||
return False
|
||||
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
|
||||
return False
|
||||
|
||||
if alpha_mask:
|
||||
if "alpha_mask" not in npz:
|
||||
return False
|
||||
if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso
|
||||
return False
|
||||
else:
|
||||
if "alpha_mask" in npz:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading file: {npz_path}")
|
||||
raise e
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||
# TODO update to use CachingStrategy
|
||||
# def load_latents_from_disk(
|
||||
# npz_path,
|
||||
# ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
# npz = np.load(npz_path)
|
||||
# if "latents" not in npz:
|
||||
# raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
||||
|
||||
# latents = npz["latents"]
|
||||
# original_size = npz["original_size"].tolist()
|
||||
# crop_ltrb = npz["crop_ltrb"].tolist()
|
||||
# flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
||||
# alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
||||
# return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||
|
||||
|
||||
# def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None):
|
||||
# kwargs = {}
|
||||
# if flipped_latents_tensor is not None:
|
||||
# kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
||||
# if alpha_mask is not None:
|
||||
# kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
||||
# np.savez(
|
||||
# npz_path,
|
||||
# latents=latents_tensor.float().cpu().numpy(),
|
||||
# original_size=np.array(original_size),
|
||||
# crop_ltrb=np.array(crop_ltrb),
|
||||
# **kwargs,
|
||||
# )
|
||||
|
||||
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
logger.info(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
logger.info(
|
||||
@@ -2865,19 +2575,19 @@ def trim_and_resize_if_required(
|
||||
# for new_cache_latents
|
||||
def load_images_and_masks_for_caching(
|
||||
image_infos: List[ImageInfo], use_alpha_mask: bool, random_crop: bool
|
||||
) -> Tuple[torch.Tensor, List[np.ndarray], List[Tuple[int, int]], List[Tuple[int, int, int, int]]]:
|
||||
) -> Tuple[torch.Tensor, list[Optional[torch.Tensor]], list[Tuple[int, int]], list[Tuple[int, int, int, int]]]:
|
||||
r"""
|
||||
requires image_infos to have: [absolute_path or image], bucket_reso, resized_size
|
||||
|
||||
returns: image_tensor, alpha_masks, original_sizes, crop_ltrbs
|
||||
|
||||
image_tensor: torch.Tensor = torch.Size([B, 3, H, W]), ...], normalized to [-1, 1]
|
||||
alpha_masks: List[np.ndarray] = [np.ndarray([H, W]), ...], normalized to [0, 1]
|
||||
alpha_masks: List[torch.Tensor] = [torch.Size([H, W]), ...], List of None if not use_alpha_mask
|
||||
original_sizes: List[Tuple[int, int]] = [(W, H), ...]
|
||||
crop_ltrbs: List[Tuple[int, int, int, int]] = [(L, T, R, B), ...]
|
||||
"""
|
||||
images: List[torch.Tensor] = []
|
||||
alpha_masks: List[np.ndarray] = []
|
||||
alpha_masks: List[torch.Tensor] = []
|
||||
original_sizes: List[Tuple[int, int]] = []
|
||||
crop_ltrbs: List[Tuple[int, int, int, int]] = []
|
||||
for info in image_infos:
|
||||
@@ -2907,158 +2617,6 @@ def load_images_and_masks_for_caching(
|
||||
return img_tensor, alpha_masks, original_sizes, crop_ltrbs
|
||||
|
||||
|
||||
def cache_batch_latents(
|
||||
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool
|
||||
) -> None:
|
||||
r"""
|
||||
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
|
||||
optionally requires image_infos to have: image
|
||||
if cache_to_disk is True, set info.latents_npz
|
||||
flipped latents is also saved if flip_aug is True
|
||||
if cache_to_disk is False, set info.latents
|
||||
latents_flipped is also set if flip_aug is True
|
||||
latents_original_size and latents_crop_ltrb are also set
|
||||
"""
|
||||
images = []
|
||||
alpha_masks: List[np.ndarray] = []
|
||||
for info in image_infos:
|
||||
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
|
||||
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
|
||||
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
|
||||
|
||||
info.latents_original_size = original_size
|
||||
info.latents_crop_ltrb = crop_ltrb
|
||||
|
||||
if use_alpha_mask:
|
||||
if image.shape[2] == 4:
|
||||
alpha_mask = image[:, :, 3] # [H,W]
|
||||
alpha_mask = alpha_mask.astype(np.float32) / 255.0
|
||||
alpha_mask = torch.FloatTensor(alpha_mask) # [H,W]
|
||||
else:
|
||||
alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W]
|
||||
else:
|
||||
alpha_mask = None
|
||||
alpha_masks.append(alpha_mask)
|
||||
|
||||
image = image[:, :, :3] # remove alpha channel if exists
|
||||
image = IMAGE_TRANSFORMS(image)
|
||||
images.append(image)
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
|
||||
if flip_aug:
|
||||
img_tensors = torch.flip(img_tensors, dims=[3])
|
||||
with torch.no_grad():
|
||||
flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
else:
|
||||
flipped_latents = [None] * len(latents)
|
||||
|
||||
for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks):
|
||||
# check NaN
|
||||
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
|
||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||
|
||||
if cache_to_disk:
|
||||
# save_latents_to_disk(
|
||||
# info.latents_npz,
|
||||
# latent,
|
||||
# info.latents_original_size,
|
||||
# info.latents_crop_ltrb,
|
||||
# flipped_latent,
|
||||
# alpha_mask,
|
||||
# )
|
||||
pass
|
||||
else:
|
||||
info.latents = latent
|
||||
if flip_aug:
|
||||
info.latents_flipped = flipped_latent
|
||||
info.alpha_mask = alpha_mask
|
||||
|
||||
if not HIGH_VRAM:
|
||||
clean_memory_on_device(vae.device)
|
||||
|
||||
|
||||
def cache_batch_text_encoder_outputs(
|
||||
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
|
||||
):
|
||||
input_ids1 = input_ids1.to(text_encoders[0].device)
|
||||
input_ids2 = input_ids2.to(text_encoders[1].device)
|
||||
|
||||
with torch.no_grad():
|
||||
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
|
||||
max_token_length,
|
||||
input_ids1,
|
||||
input_ids2,
|
||||
tokenizers[0],
|
||||
tokenizers[1],
|
||||
text_encoders[0],
|
||||
text_encoders[1],
|
||||
dtype,
|
||||
)
|
||||
|
||||
# ここでcpuに移動しておかないと、上書きされてしまう
|
||||
b_hidden_state1 = b_hidden_state1.detach().to("cpu") # b,n*75+2,768
|
||||
b_hidden_state2 = b_hidden_state2.detach().to("cpu") # b,n*75+2,1280
|
||||
b_pool2 = b_pool2.detach().to("cpu") # b,1280
|
||||
|
||||
for info, hidden_state1, hidden_state2, pool2 in zip(image_infos, b_hidden_state1, b_hidden_state2, b_pool2):
|
||||
if cache_to_disk:
|
||||
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, hidden_state1, hidden_state2, pool2)
|
||||
else:
|
||||
info.text_encoder_outputs1 = hidden_state1
|
||||
info.text_encoder_outputs2 = hidden_state2
|
||||
info.text_encoder_pool2 = pool2
|
||||
|
||||
|
||||
def cache_batch_text_encoder_outputs_sd3(
|
||||
image_infos, tokenizer, text_encoders, max_token_length, cache_to_disk, input_ids, output_dtype
|
||||
):
|
||||
# make input_ids for each text encoder
|
||||
l_tokens, g_tokens, t5_tokens = input_ids
|
||||
|
||||
clip_l, clip_g, t5xxl = text_encoders
|
||||
with torch.no_grad():
|
||||
b_lg_out, b_t5_out, b_pool = sd3_utils.get_cond_from_tokens(
|
||||
l_tokens, g_tokens, t5_tokens, clip_l, clip_g, t5xxl, "cpu", output_dtype
|
||||
)
|
||||
b_lg_out = b_lg_out.detach()
|
||||
b_t5_out = b_t5_out.detach()
|
||||
b_pool = b_pool.detach()
|
||||
|
||||
for info, lg_out, t5_out, pool in zip(image_infos, b_lg_out, b_t5_out, b_pool):
|
||||
# debug: NaN check
|
||||
if torch.isnan(lg_out).any() or torch.isnan(t5_out).any() or torch.isnan(pool).any():
|
||||
raise RuntimeError(f"NaN detected in text encoder outputs: {info.absolute_path}")
|
||||
|
||||
if cache_to_disk:
|
||||
save_text_encoder_outputs_to_disk(info.text_encoder_outputs_npz, lg_out, t5_out, pool)
|
||||
else:
|
||||
info.text_encoder_outputs1 = lg_out
|
||||
info.text_encoder_outputs2 = t5_out
|
||||
info.text_encoder_pool2 = pool
|
||||
|
||||
|
||||
def save_text_encoder_outputs_to_disk(npz_path, hidden_state1, hidden_state2, pool2):
|
||||
np.savez(
|
||||
npz_path,
|
||||
hidden_state1=hidden_state1.cpu().float().numpy(),
|
||||
hidden_state2=hidden_state2.cpu().float().numpy(),
|
||||
pool2=pool2.cpu().float().numpy(),
|
||||
)
|
||||
|
||||
|
||||
def load_text_encoder_outputs_from_disk(npz_path):
|
||||
with np.load(npz_path) as f:
|
||||
hidden_state1 = torch.from_numpy(f["hidden_state1"])
|
||||
hidden_state2 = torch.from_numpy(f["hidden_state2"]) if "hidden_state2" in f else None
|
||||
pool2 = torch.from_numpy(f["pool2"]) if "pool2" in f else None
|
||||
return hidden_state1, hidden_state2, pool2
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region モジュール入れ替え部
|
||||
@@ -4357,6 +3915,12 @@ def add_dataset_arguments(
|
||||
action="store_true",
|
||||
help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--force_cache_precision",
|
||||
action="store_true",
|
||||
help="force cache precision to match the model precision. this option re-caches latents if the precision is lower than the model precision"
|
||||
" / cacheの精度をモデルの精度に合わせる。このオプションを指定すると、精度がモデルの精度よりも低い場合にlatentを再キャッシュします",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_cache_check",
|
||||
action="store_true",
|
||||
@@ -5913,7 +5477,7 @@ def append_lr_to_logs(logs, lr_scheduler, optimizer_type, including_unet=True):
|
||||
names.append("unet")
|
||||
names.append("text_encoder1")
|
||||
names.append("text_encoder2")
|
||||
names.append("text_encoder3") # SD3
|
||||
names.append("text_encoder3") # SD3
|
||||
|
||||
append_lr_to_logs_with_names(logs, lr_scheduler, optimizer_type, names)
|
||||
|
||||
|
||||
@@ -189,6 +189,15 @@ def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None)
|
||||
raise ValueError(f"Unsupported dtype: {s}")
|
||||
|
||||
|
||||
def dtype_to_normalized_str(dtype: Union[str, torch.dtype]) -> str:
|
||||
dtype = str_to_dtype(dtype) if isinstance(dtype, str) else dtype
|
||||
|
||||
# get name of the dtype
|
||||
dtype_name = str(dtype).split(".")[-1]
|
||||
|
||||
return dtype_name
|
||||
|
||||
|
||||
def mem_eff_save_file(tensors: Dict[str, torch.Tensor], filename: str, metadata: Dict[str, Any] = None):
|
||||
"""
|
||||
memory efficient save file
|
||||
@@ -264,8 +273,8 @@ class MemoryEfficientSafeOpen:
|
||||
# does not support metadata loading
|
||||
def __init__(self, filename):
|
||||
self.filename = filename
|
||||
self.header, self.header_size = self._read_header()
|
||||
self.file = open(filename, "rb")
|
||||
self.header, self.header_size = self._read_header()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
@@ -276,6 +285,9 @@ class MemoryEfficientSafeOpen:
|
||||
def keys(self):
|
||||
return [k for k in self.header.keys() if k != "__metadata__"]
|
||||
|
||||
def metadata(self) -> Dict[str, str]:
|
||||
return self.header.get("__metadata__", {})
|
||||
|
||||
def get_tensor(self, key):
|
||||
if key not in self.header:
|
||||
raise KeyError(f"Tensor '{key}' not found in the file")
|
||||
@@ -293,10 +305,9 @@ class MemoryEfficientSafeOpen:
|
||||
return self._deserialize_tensor(tensor_bytes, metadata)
|
||||
|
||||
def _read_header(self):
|
||||
with open(self.filename, "rb") as f:
|
||||
header_size = struct.unpack("<Q", f.read(8))[0]
|
||||
header_json = f.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
header_size = struct.unpack("<Q", self.file.read(8))[0]
|
||||
header_json = self.file.read(header_size).decode("utf-8")
|
||||
return json.loads(header_json), header_size
|
||||
|
||||
def _deserialize_tensor(self, tensor_bytes, metadata):
|
||||
dtype = self._get_torch_dtype(metadata["dtype"])
|
||||
|
||||
@@ -330,7 +330,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu") # if no sampling, vae can be deleted
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -272,7 +272,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -209,7 +209,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -181,7 +181,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -149,7 +149,7 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
||||
|
||||
# cache latents with dataset
|
||||
# TODO use DataLoader to speed up
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.print(f"Finished caching latents to disk.")
|
||||
|
||||
@@ -156,7 +156,7 @@ def train(args):
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -418,7 +418,7 @@ class NetworkTrainer:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -378,7 +378,7 @@ class TextualInversionTrainer:
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
|
||||
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||
train_dataset_group.new_cache_latents(vae, accelerator, args.force_cache_precision)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
Reference in New Issue
Block a user