From 089a63c57316566749d9415257c9172e15cbe672 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 1 Mar 2023 21:12:33 +0900 Subject: [PATCH] shuffle at debug_dataset --- library/train_util.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index b191604c..c377a56a 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1080,7 +1080,10 @@ def debug_dataset(train_dataset, show_input_ids=False): train_dataset.set_current_epoch(1) k = 0 - for i, example in enumerate(train_dataset): + indices = list(range(len(train_dataset))) + random.shuffle(indices) + for i, idx in enumerate(indices): + example = train_dataset[idx] if example['latents'] is not None: print(f"sample has latents from npz file: {example['latents'].size()}") for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):