Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-01-31 16:37:47 +08:00
committed by GitHub
parent 6de0051eb2
commit a979ea5a50

View File

@@ -2881,7 +2881,13 @@ def main(args):
n, m = divmod(len(sublist), device)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)])
batch_separated_list.append(split_into_batches)
if distributed_state.num_processes > 1:
templist = []
for i in range(distributed_state.num_processes):
templist.append(batch_separated_list[i :: distributed_state.num_processes])
batch_separated_list = []
for sub_batch_list in templist:
batch_separated_list.extend(sub_batch_list)
distributed_state.wait_for_everyone()
batch_data = gather_object(batch_separated_list)
del extinfo