diff --git a/library/train_util.py b/library/train_util.py index 1d616376..534771ad 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1009,6 +1009,10 @@ class BaseDataset(torch.utils.data.Dataset): return None, None, None npz = np.load(npz_file) + if "latents" not in npz: + print(f"error: npz is old format. please re-generate {npz_file}") + return None, None, None + latents = npz["latents"] original_size = npz["original_size"].tolist() crop_left_top = npz["crop_left_top"].tolist()