diff --git a/library/train_util.py b/library/train_util.py index 8dbd4dbb..ec352a1e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -678,8 +678,16 @@ class BaseDataset(torch.utils.data.Dataset): def cache_latents(self, vae, vae_batch_size=1): # ちょっと速くした print("caching latents.") - infos = [] - for info in self.image_data.values(): + + image_infos= list(self.image_data.values()) + + # sort by resolution + image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) + + # split by resolution + batches = [] + batch = [] + for info in image_infos: subset = self.image_to_subset[info.image_key] if info.latents_npz is not None: @@ -690,10 +698,23 @@ class BaseDataset(torch.utils.data.Dataset): info.latents_flipped = torch.FloatTensor(info.latents_flipped) continue - infos.append(info) + # if last member of batch has different resolution, flush the batch + if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: + batches.append(batch) + batch = [] - for i in tqdm(range(0, len(infos), vae_batch_size), smoothing=1, total=len(infos) // vae_batch_size): - batch = infos[i : i + vae_batch_size] + batch.append(info) + + # if number of data in batch is enough, flush the batch + if len(batch) >= vae_batch_size: + batches.append(batch) + batch = [] + + if len(batch) > 0: + batches.append(batch) + + # iterate batches + for batch in tqdm(batches, smoothing=1, total=len(batches)): images = [] for info in batch: image = self.load_image(info.absolute_path)