Update train_network.py

This commit is contained in:
DKnight54
2025-01-31 15:07:18 +08:00
committed by GitHub
parent e723e457ad
commit 6de0051eb2

View File

@@ -1022,7 +1022,16 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = None, None, None
# Checks if the accelerator has performed an optimization step behind the scenes
example_tuple = (latents, batch["captions"])
# Collecting latents and caption lists from all processes
all_lists_of_latents = gather_object(latents)
all_lists_of_captions = gather_object(batch["captions"])
all_latents = []
all_captions = []
for ilatents in all_lists_of_latents:
all_latents.extend(ilatents)
for icaptions in all_lists_of_captions:
all_captions.extend(icaptions)
example_tuple = (all_latents, all_captions)
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1