Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-02-09 01:21:05 +08:00
committed by GitHub
parent 58d6434587
commit 0a74083908

View File

@@ -2902,14 +2902,14 @@ def main(args):
res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]]
for index in res:
templist.append(prompt_data_list[index])
'''
if distributed_state.num_processes > 1:
resorted_list = []
for i in range(distributed_state.num_processes):
resorted_list.append(templist[i :: distributed_state.num_processes])
for list_of_prompts in resorted_list:
templist.extend(list_of_prompts)
'''
split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy()
sublist = []
@@ -2926,6 +2926,7 @@ def main(args):
elif len(split_into_batches) == 1 :
sublist.extend(split_into_batches.pop(-1))
split_into_batches = []
# sublist = sorted(sublist, key=lambda x: x.global_count)
n, m = divmod(len(sublist), distributed_state.num_processes)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])