mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
WIP: update new latents caching
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
import glob
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
@@ -282,12 +283,26 @@ def sample_images(*args, **kwargs):
|
|||||||
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
|
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
|
||||||
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
|
||||||
|
|
||||||
def __init__(self, vae: sd3_models.SDVAE, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||||
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
|
||||||
|
self.vae = None
|
||||||
|
|
||||||
|
def set_vae(self, vae: sd3_models.SDVAE):
|
||||||
self.vae = vae
|
self.vae = vae
|
||||||
|
|
||||||
def get_latents_npz_path(self, absolute_path: str):
|
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||||
return os.path.splitext(absolute_path)[0] + self.SD3_LATENTS_NPZ_SUFFIX
|
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
|
||||||
|
if len(npz_file) == 0:
|
||||||
|
return None, None
|
||||||
|
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
|
||||||
|
return int(w), int(h)
|
||||||
|
|
||||||
|
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
|
||||||
|
return (
|
||||||
|
os.path.splitext(absolute_path)[0]
|
||||||
|
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
|
||||||
|
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
|
||||||
|
)
|
||||||
|
|
||||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||||
if not self.cache_to_disk:
|
if not self.cache_to_disk:
|
||||||
@@ -331,24 +346,24 @@ class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
|
|||||||
img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
|
img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
latents = self.vae.encode(img_tensor).to("cpu")
|
latents_tensors = self.vae.encode(img_tensor).to("cpu")
|
||||||
if flip_aug:
|
if flip_aug:
|
||||||
img_tensor = torch.flip(img_tensor, dims=[3])
|
img_tensor = torch.flip(img_tensor, dims=[3])
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
flipped_latents = self.vae.encode(img_tensor).to("cpu")
|
flipped_latents = self.vae.encode(img_tensor).to("cpu")
|
||||||
else:
|
else:
|
||||||
flipped_latents = [None] * len(latents)
|
flipped_latents = [None] * len(latents_tensors)
|
||||||
|
|
||||||
|
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
|
||||||
|
for i in range(len(image_infos)):
|
||||||
|
info = image_infos[i]
|
||||||
|
latents = latents_tensors[i]
|
||||||
|
flipped_latent = flipped_latents[i]
|
||||||
|
alpha_mask = alpha_masks[i]
|
||||||
|
original_size = original_sizes[i]
|
||||||
|
crop_ltrb = crop_ltrbs[i]
|
||||||
|
|
||||||
for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks):
|
|
||||||
if self.cache_to_disk:
|
if self.cache_to_disk:
|
||||||
# save_latents_to_disk(
|
|
||||||
# info.latents_npz,
|
|
||||||
# latent,
|
|
||||||
# info.latents_original_size,
|
|
||||||
# info.latents_crop_ltrb,
|
|
||||||
# flipped_latent,
|
|
||||||
# alpha_mask,
|
|
||||||
# )
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if flipped_latent is not None:
|
if flipped_latent is not None:
|
||||||
kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
|
kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
|
||||||
@@ -357,12 +372,12 @@ class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
|
|||||||
np.savez(
|
np.savez(
|
||||||
info.latents_npz,
|
info.latents_npz,
|
||||||
latents=latents.float().cpu().numpy(),
|
latents=latents.float().cpu().numpy(),
|
||||||
original_size=np.array(original_sizes),
|
original_size=np.array(original_size),
|
||||||
crop_ltrb=np.array(crop_ltrbs),
|
crop_ltrb=np.array(crop_ltrb),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
info.latents = latent
|
info.latents = latents
|
||||||
if flip_aug:
|
if flip_aug:
|
||||||
info.latents_flipped = flipped_latent
|
info.latents_flipped = flipped_latent
|
||||||
info.alpha_mask = alpha_mask
|
info.alpha_mask = alpha_mask
|
||||||
|
|||||||
@@ -360,11 +360,23 @@ class AugHelper:
|
|||||||
|
|
||||||
|
|
||||||
class LatentsCachingStrategy:
|
class LatentsCachingStrategy:
|
||||||
|
_strategy = None # strategy instance: actual strategy class
|
||||||
|
|
||||||
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
|
||||||
self._cache_to_disk = cache_to_disk
|
self._cache_to_disk = cache_to_disk
|
||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
self.skip_disk_cache_validity_check = skip_disk_cache_validity_check
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_strategy(cls, strategy):
|
||||||
|
if cls._strategy is not None:
|
||||||
|
raise RuntimeError(f"Internal error. {cls.__name__} strategy is already set")
|
||||||
|
cls._strategy = strategy
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_strategy(cls) -> Optional["LatentsCachingStrategy"]:
|
||||||
|
return cls._strategy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cache_to_disk(self):
|
def cache_to_disk(self):
|
||||||
return self._cache_to_disk
|
return self._cache_to_disk
|
||||||
@@ -373,10 +385,15 @@ class LatentsCachingStrategy:
|
|||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return self._batch_size
|
return self._batch_size
|
||||||
|
|
||||||
def get_latents_npz_path(self, absolute_path: str):
|
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
def get_latents_npz_path(self, absolute_path: str, bucket_reso: Tuple[int, int]) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def is_disk_cached_latents_expected(
|
||||||
|
self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
||||||
|
) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
def cache_batch_latents(self, batch: List[ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||||
@@ -1034,7 +1051,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
# check disk cache exists and size of latents
|
# check disk cache exists and size of latents
|
||||||
if caching_strategy.cache_to_disk:
|
if caching_strategy.cache_to_disk:
|
||||||
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||||
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path)
|
info.latents_npz = caching_strategy.get_latents_npz_path(info.absolute_path, info.image_size)
|
||||||
if not is_main_process: # prepare for multi-gpu, only store to info
|
if not is_main_process: # prepare for multi-gpu, only store to info
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -1730,6 +1747,18 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
img_paths = glob_images(subset.image_dir, "*")
|
img_paths = glob_images(subset.image_dir, "*")
|
||||||
sizes = [None] * len(img_paths)
|
sizes = [None] * len(img_paths)
|
||||||
|
|
||||||
|
# new caching: get image size from cache files
|
||||||
|
strategy = LatentsCachingStrategy.get_strategy()
|
||||||
|
if strategy is not None:
|
||||||
|
logger.info("get image size from cache files")
|
||||||
|
size_set_count = 0
|
||||||
|
for i, img_path in enumerate(tqdm(img_paths)):
|
||||||
|
w, h = strategy.get_image_size_from_image_absolute_path(img_path)
|
||||||
|
if w is not None and h is not None:
|
||||||
|
sizes[i] = [w, h]
|
||||||
|
size_set_count += 1
|
||||||
|
logger.info(f"set image size from cache files: {size_set_count}/{len(img_paths)}")
|
||||||
|
|
||||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||||
|
|
||||||
if use_cached_info_for_subset:
|
if use_cached_info_for_subset:
|
||||||
|
|||||||
15
sd3_train.py
15
sd3_train.py
@@ -91,6 +91,15 @@ def train(args):
|
|||||||
# load tokenizer
|
# load tokenizer
|
||||||
sd3_tokenizer = sd3_models.SD3Tokenizer()
|
sd3_tokenizer = sd3_models.SD3Tokenizer()
|
||||||
|
|
||||||
|
# prepare caching strategy
|
||||||
|
if args.new_caching:
|
||||||
|
latents_caching_strategy = sd3_train_utils.Sd3LatentsCachingStrategy(
|
||||||
|
args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
latents_caching_strategy = None
|
||||||
|
train_util.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||||
@@ -217,10 +226,8 @@ def train(args):
|
|||||||
file_suffix="_sd3.npz",
|
file_suffix="_sd3.npz",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
strategy = sd3_train_utils.Sd3LatentsCachingStrategy(
|
latents_caching_strategy.set_vae(vae)
|
||||||
vae, args.cache_latents_to_disk, args.vae_batch_size, args.skip_latents_validity_check
|
train_dataset_group.new_cache_latents(accelerator.is_main_process, latents_caching_strategy)
|
||||||
)
|
|
||||||
train_dataset_group.new_cache_latents(accelerator.is_main_process, strategy)
|
|
||||||
vae.to("cpu") # if no sampling, vae can be deleted
|
vae.to("cpu") # if no sampling, vae can be deleted
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user