fix latents disk cache

This commit is contained in:
Kohya S
2023-04-13 21:14:39 +09:00
parent 9ff32fd4c0
commit a8632b7329

View File

@@ -758,15 +758,15 @@ class BaseDataset(torch.utils.data.Dataset):
cache_available = False cache_available = False
expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意 expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8) # bucket_resoはWxHなので注意
if os.path.exists(info.latents_npz): if os.path.exists(info.latents_npz):
cached_latents = np.load(info.latents_npz) cached_latents = np.load(info.latents_npz)["arr_0"]
if cached_latents["latents"].shape[1:3] == expected_latents_size: if cached_latents.shape[1:3] == expected_latents_size:
cache_available = True cache_available = True
if subset.flip_aug: if subset.flip_aug:
cache_available = False cache_available = False
if os.path.exists(info.latents_npz_flipped): if os.path.exists(info.latents_npz_flipped):
cached_latents_flipped = np.load(info.latents_npz_flipped) cached_latents_flipped = np.load(info.latents_npz_flipped)["arr_0"]
if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size: if cached_latents_flipped.shape[1:3] == expected_latents_size:
cache_available = True cache_available = True
if cache_available: if cache_available: