mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix different reso in batch
This commit is contained in:
@@ -678,8 +678,16 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def cache_latents(self, vae, vae_batch_size=1):
|
||||
# ちょっと速くした
|
||||
print("caching latents.")
|
||||
infos = []
|
||||
for info in self.image_data.values():
|
||||
|
||||
image_infos= list(self.image_data.values())
|
||||
|
||||
# sort by resolution
|
||||
image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1])
|
||||
|
||||
# split by resolution
|
||||
batches = []
|
||||
batch = []
|
||||
for info in image_infos:
|
||||
subset = self.image_to_subset[info.image_key]
|
||||
|
||||
if info.latents_npz is not None:
|
||||
@@ -690,10 +698,23 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
info.latents_flipped = torch.FloatTensor(info.latents_flipped)
|
||||
continue
|
||||
|
||||
infos.append(info)
|
||||
# if last member of batch has different resolution, flush the batch
|
||||
if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
for i in tqdm(range(0, len(infos), vae_batch_size), smoothing=1, total=len(infos) // vae_batch_size):
|
||||
batch = infos[i : i + vae_batch_size]
|
||||
batch.append(info)
|
||||
|
||||
# if number of data in batch is enough, flush the batch
|
||||
if len(batch) >= vae_batch_size:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
# iterate batches
|
||||
for batch in tqdm(batches, smoothing=1, total=len(batches)):
|
||||
images = []
|
||||
for info in batch:
|
||||
image = self.load_image(info.absolute_path)
|
||||
|
||||
Reference in New Issue
Block a user