diff --git a/library/train_util.py b/library/train_util.py index b191604c..c377a56a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1080,7 +1080,10 @@ def debug_dataset(train_dataset, show_input_ids=False): train_dataset.set_current_epoch(1) k = 0 - for i, example in enumerate(train_dataset): + indices = list(range(len(train_dataset))) + random.shuffle(indices) + for i, idx in enumerate(indices): + example = train_dataset[idx] if example['latents'] is not None: print(f"sample has latents from npz file: {example['latents'].size()}") for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):