From 68a039f57bb5b803dc23e755c92b451c13fdd55a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 18:17:51 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index e2fe2eda..6d28b945 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2916,14 +2916,16 @@ def main(args): n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) - #if distributed_state.num_processes > 1: - # for x in range(len(ext_separated_list_of_batches)): - # temp_list = [] - # for i in range(distributed_state.num_processes): - # temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - # ext_separated_list_of_batches[x] = [] - # for batches in temp_list: - # ext_separated_list_of_batches[x].extend(batches) + logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") + if distributed_state.num_processes > 1: + for x in range(len(ext_separated_list_of_batches)): + temp_list = [] + for i in range(distributed_state.num_processes): + temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + holder = [] + for batches in temp_list: + holder.extend(batches) + ext_separated_list_of_batches[x] = holder[:] distributed_state.wait_for_everyone() # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)