Fix latent caching when a batch would contain images of different resolutions

This commit is contained in:
Kohya S
2023-03-21 18:33:46 +09:00
parent 1816ac3271
commit 6d9f3bc0b2

View File

@@ -678,8 +678,16 @@ class BaseDataset(torch.utils.data.Dataset):
def cache_latents(self, vae, vae_batch_size=1):
# ちょっと速くした (made this a bit faster)
print("caching latents.")
infos = []
for info in self.image_data.values():
image_infos= list(self.image_data.values())
# sort by resolution
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
# split by resolution
batches = []
batch = []
for info in image_infos:
subset = self.image_to_subset[info.image_key]
if info.latents_npz is not None:
@@ -690,10 +698,23 @@ class BaseDataset(torch.utils.data.Dataset):
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
continue
infos.append(info)
# if last member of batch has different resolution, flush the batch
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
batches.append(batch)
batch = []
for i in tqdm(range(0, len(infos), vae_batch_size), smoothing=1, total=len(infos) // vae_batch_size):
batch = infos[i : i + vae_batch_size]
batch.append(info)
# if number of data in batch is enough, flush the batch
if len(batch) >= vae_batch_size:
batches.append(batch)
batch = []
if len(batch) > 0:
batches.append(batch)
# iterate batches
for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = []
for info in batch:
image = self.load_image(info.absolute_path)