From a8632b7329b8c6f558f0c707b21d5ead40cb33cb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 13 Apr 2023 21:14:39 +0900 Subject: [PATCH] fix latents disk cache --- library/train_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 6b398707..013cc81c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -758,15 +758,15 @@ class BaseDataset(torch.utils.data.Dataset): cache_available = False expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意 if os.path.exists(info.latents_npz): - cached_latents = np.load(info.latents_npz) - if cached_latents["latents"].shape[1:3] == expected_latents_size: + cached_latents = np.load(info.latents_npz)["arr_0"] + if cached_latents.shape[1:3] == expected_latents_size: cache_available = True if subset.flip_aug: cache_available = False if os.path.exists(info.latents_npz_flipped): - cached_latents_flipped = np.load(info.latents_npz_flipped) - if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size: + cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"] + if cached_latents_flipped.shape[1:3] == expected_latents_size: cache_available = True if cache_available: