Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-02-08 18:17:51 +08:00
committed by GitHub
parent 6e30e5726b
commit 68a039f57b

View File

@@ -2916,14 +2916,16 @@ def main(args):
n, m = divmod(len(sublist), distributed_state.num_processes)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
ext_separated_list_of_batches.append(split_into_batches)
#if distributed_state.num_processes > 1:
# for x in range(len(ext_separated_list_of_batches)):
# temp_list = []
# for i in range(distributed_state.num_processes):
# temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes])
# ext_separated_list_of_batches[x] = []
# for batches in temp_list:
# ext_separated_list_of_batches[x].extend(batches)
logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}")
if distributed_state.num_processes > 1:
for x in range(len(ext_separated_list_of_batches)):
temp_list = []
for i in range(distributed_state.num_processes):
temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes])
holder = []
for batches in temp_list:
holder.extend(batches)
ext_separated_list_of_batches[x] = holder[:]
distributed_state.wait_for_everyone()
# ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)