Update train_network.py

This commit is contained in:
DKnight54
2025-01-31 21:47:15 +08:00
committed by GitHub
parent 216596719b
commit 2cdaa33147

View File

@@ -1023,14 +1023,8 @@ class NetworkTrainer:
# Checks if the accelerator has performed an optimization step behind the scenes
# Collecting latents and caption lists from all processes
all_lists_of_latents = gather_object(latents)
all_lists_of_captions = gather_object(batch["captions"])
all_latents = []
all_captions = []
for ilatents in all_lists_of_latents:
all_latents.extend(ilatents)
for icaptions in all_lists_of_captions:
all_captions.extend(icaptions)
all_latents = gather_object(latents)
all_captions = gather_object(batch["captions"])
example_tuple = (all_latents, all_captions)
if accelerator.sync_gradients:
progress_bar.update(1)