diff --git a/library/train_util.py b/library/train_util.py index 746a7f9d..809f0af0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -91,6 +91,7 @@ IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ". try: import pillow_avif + IMAGE_EXTENSIONS.extend([".avif", ".AVIF"]) except: pass @@ -853,16 +854,11 @@ class BaseDataset(torch.utils.data.Dataset): # split by resolution batches = [] batch = [] - for info in image_infos: + print("checking cache validity...") + for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] - if info.latents_npz is not None: - info.latents, info.latents_original_size, info.latents_crop_left_top = self.load_latents_from_npz(info, False) - info.latents = torch.FloatTensor(info.latents) - - info.latents_flipped, _, _ = self.load_latents_from_npz(info, True) # might be None - if info.latents_flipped is not None: - info.latents_flipped = torch.FloatTensor(info.latents_flipped) + if info.latents_npz is not None: # fine tuning dataset continue # check disk cache exists and size of latents