diff --git a/train_network.py b/train_network.py index 2554b230..5f2a31da 100644 --- a/train_network.py +++ b/train_network.py @@ -1022,7 +1022,16 @@ class NetworkTrainer: keys_scaled, mean_norm, maximum_norm = None, None, None # Checks if the accelerator has performed an optimization step behind the scenes - example_tuple = (latents, batch["captions"]) + # Collecting latents and caption lists from all processes + all_lists_of_latents = gather_object(latents) + all_lists_of_captions = gather_object(batch["captions"]) + all_latents = [] + all_captions = [] + for ilatents in all_lists_of_latents: + all_latents.extend(ilatents) + for icaptions in all_lists_of_captions: + all_captions.extend(icaptions) + example_tuple = (all_latents, all_captions) if accelerator.sync_gradients: progress_bar.update(1) global_step += 1