diff --git a/train_network.py b/train_network.py index 5f2a31da..7563424b 100644 --- a/train_network.py +++ b/train_network.py @@ -1023,14 +1023,8 @@ class NetworkTrainer: # Checks if the accelerator has performed an optimization step behind the scenes # Collecting latents and caption lists from all processes - all_lists_of_latents = gather_object(latents) - all_lists_of_captions = gather_object(batch["captions"]) - all_latents = [] - all_captions = [] - for ilatents in all_lists_of_latents: - all_latents.extend(ilatents) - for icaptions in all_lists_of_captions: - all_captions.extend(icaptions) + all_latents = gather_object(latents) + all_captions = gather_object(batch["captions"]) example_tuple = (all_latents, all_captions) if accelerator.sync_gradients: progress_bar.update(1)