shuffle at debug_dataset

This commit is contained in:
Kohya S
2023-03-01 21:12:33 +09:00
parent ed19a92bbe
commit 089a63c573

View File

@@ -1080,7 +1080,10 @@ def debug_dataset(train_dataset, show_input_ids=False):
     train_dataset.set_current_epoch(1)
     k = 0
-    for i, example in enumerate(train_dataset):
+    indices = list(range(len(train_dataset)))
+    random.shuffle(indices)
+    for i, idx in enumerate(indices):
+        example = train_dataset[idx]
         if example['latents'] is not None:
             print(f"sample has latents from npz file: {example['latents'].size()}")
         for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):