From b56634093ce9f31c2d857646ff55a712b5255f06 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 8 Feb 2025 10:53:17 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 3b25d2c5..d72753c9 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2889,7 +2889,7 @@ def main(args): extinfo = gather_object(extinfo) ext_separated_list_of_batches = [] - if len(prompt_data_list) > 0 and distributed_state.is_main_process: + if len(prompt_data_list) > 0: unique_extinfo = list(set(extinfo)) # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): @@ -2916,14 +2916,14 @@ def main(args): n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) - #if distributed_state.num_processes > 1: - # for x in range(len(ext_separated_list_of_batches)): - # temp_list = [] - # for i in range(distributed_state.num_processes): - # temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - # ext_separated_list_of_batches[x] = [] - # for batches in temp_list: - # ext_separated_list_of_batches[x].extend(batches) + if distributed_state.num_processes > 1: + for x in range(len(ext_separated_list_of_batches)): + temp_list = [] + for i in range(distributed_state.num_processes): + temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + ext_separated_list_of_batches[x] = [] + for batches in temp_list: + ext_separated_list_of_batches[x].extend(batches) distributed_state.wait_for_everyone() # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)