Refactor caching mechanism for latents and text encoder outputs, etc.

This commit is contained in:
Kohya S
2024-07-27 13:50:05 +09:00
parent 082f13658b
commit 41dee60383
21 changed files with 1786 additions and 733 deletions

View File

@@ -280,111 +280,6 @@ def sample_images(*args, **kwargs):
return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
class Sd3LatentsCachingStrategy(train_util.LatentsCachingStrategy):
SD3_LATENTS_NPZ_SUFFIX = "_sd3.npz"
def __init__(self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool) -> None:
super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)
self.vae = None
def set_vae(self, vae: sd3_models.SDVAE):
self.vae = vae
def get_image_size_from_image_absolute_path(self, absolute_path: str) -> Tuple[Optional[int], Optional[int]]:
npz_file = glob.glob(os.path.splitext(absolute_path)[0] + "_*" + Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX)
if len(npz_file) == 0:
return None, None
w, h = os.path.splitext(npz_file[0])[0].split("_")[-2].split("x")
return int(w), int(h)
def get_latents_npz_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return (
os.path.splitext(absolute_path)[0]
+ f"_{image_size[0]:04d}x{image_size[1]:04d}"
+ Sd3LatentsCachingStrategy.SD3_LATENTS_NPZ_SUFFIX
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
return False
if self.skip_disk_cache_validity_check:
return True
expected_latents_size = (bucket_reso[1] // 8, bucket_reso[0] // 8) # bucket_reso is (W, H)
try:
npz = np.load(npz_path)
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
return False
else:
if "alpha_mask" in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
raise e
return True
def cache_batch_latents(self, image_infos: List[train_util.ImageInfo], flip_aug: bool, alpha_mask: bool, random_crop: bool):
img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=self.vae.device, dtype=self.vae.dtype)
with torch.no_grad():
latents_tensors = self.vae.encode(img_tensor).to("cpu")
if flip_aug:
img_tensor = torch.flip(img_tensor, dims=[3])
with torch.no_grad():
flipped_latents = self.vae.encode(img_tensor).to("cpu")
else:
flipped_latents = [None] * len(latents_tensors)
# for info, latents, flipped_latent, alpha_mask in zip(image_infos, latents_tensors, flipped_latents, alpha_masks):
for i in range(len(image_infos)):
info = image_infos[i]
latents = latents_tensors[i]
flipped_latent = flipped_latents[i]
alpha_mask = alpha_masks[i]
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
if self.cache_to_disk:
kwargs = {}
if flipped_latent is not None:
kwargs["latents_flipped"] = flipped_latent.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
info.latents_npz,
latents=latents.float().cpu().numpy(),
original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
else:
info.latents = latents
if flip_aug:
info.latents_flipped = flipped_latent
info.alpha_mask = alpha_mask
if not train_util.HIGH_VRAM:
clean_memory_on_device(self.vae.device)
# region Diffusers