feat: Support multi-resolution training with caching latents to disk

This commit is contained in:
Kohya S
2024-08-20 21:39:43 +09:00
parent 388b3b4b74
commit 6ab48b09d8
4 changed files with 93 additions and 43 deletions

View File

@@ -219,7 +219,13 @@ class LatentsCachingStrategy:
raise NotImplementedError
def _default_is_disk_cached_latents_expected(
self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
multi_resolution: bool = False,
):
if not self.cache_to_disk:
return False
@@ -230,25 +236,17 @@ class LatentsCachingStrategy:
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try:
npz = np.load(npz_path)
if npz["latents"].shape[1:3] != expected_latents_size:
if "latents" + key_reso_suffix not in npz:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
@@ -257,7 +255,15 @@ class LatentsCachingStrategy:
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
self,
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
@@ -287,8 +293,13 @@ class LatentsCachingStrategy:
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[1:3] # H, W
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:
self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
self.save_latents_to_disk(
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -298,31 +309,56 @@ class LatentsCachingStrategy:
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
"""
for SD/SDXL/SD3.0
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_ltrb = npz["crop_ltrb"].tolist()
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
self,
npz_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
kwargs = {}
if os.path.exists(npz_path):
# load existing npz and update it
npz = np.load(npz_path)
for key in npz.files:
kwargs[key] = npz[key]
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)