Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-02-08 10:53:17 +08:00
committed by GitHub
parent 03611217bf
commit b56634093c

View File

@@ -2889,7 +2889,7 @@ def main(args):
extinfo = gather_object(extinfo)
ext_separated_list_of_batches = []
if len(prompt_data_list) > 0 and distributed_state.is_main_process:
if len(prompt_data_list) > 0:
unique_extinfo = list(set(extinfo))
# splits list of prompts into sublists where BatchDataExt ext is identical
for i in range(len(unique_extinfo)):
@@ -2916,14 +2916,14 @@ def main(args):
n, m = divmod(len(sublist), distributed_state.num_processes)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
ext_separated_list_of_batches.append(split_into_batches)
#if distributed_state.num_processes > 1:
# for x in range(len(ext_separated_list_of_batches)):
# temp_list = []
# for i in range(distributed_state.num_processes):
# temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes])
# ext_separated_list_of_batches[x] = []
# for batches in temp_list:
# ext_separated_list_of_batches[x].extend(batches)
if distributed_state.num_processes > 1:
for x in range(len(ext_separated_list_of_batches)):
temp_list = []
for i in range(distributed_state.num_processes):
temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes])
ext_separated_list_of_batches[x] = []
for batches in temp_list:
ext_separated_list_of_batches[x].extend(batches)
distributed_state.wait_for_everyone()
# ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)