diff --git a/library/train_util.py b/library/train_util.py index 6b398707..013cc81c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -758,15 +758,15 @@ class BaseDataset(torch.utils.data.Dataset): cache_available = False expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意 if os.path.exists(info.latents_npz): - cached_latents = np.load(info.latents_npz) - if cached_latents["latents"].shape[1:3] == expected_latents_size: + cached_latents = np.load(info.latents_npz)["arr_0"] + if cached_latents.shape[1:3] == expected_latents_size: cache_available = True if subset.flip_aug: cache_available = False if os.path.exists(info.latents_npz_flipped): - cached_latents_flipped = np.load(info.latents_npz_flipped) - if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size: + cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"] + if cached_latents_flipped.shape[1:3] == expected_latents_size: cache_available = True if cache_available: