diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 3b25d2c5..d72753c9 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2889,7 +2889,7 @@ def main(args): extinfo = gather_object(extinfo) ext_separated_list_of_batches = [] - if len(prompt_data_list) > 0 and distributed_state.is_main_process: + if len(prompt_data_list) > 0: unique_extinfo = list(set(extinfo)) # splits list of prompts into sublists where BatchDataExt ext is identical for i in range(len(unique_extinfo)): @@ -2916,14 +2916,14 @@ def main(args): n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) - #if distributed_state.num_processes > 1: - # for x in range(len(ext_separated_list_of_batches)): - # temp_list = [] - # for i in range(distributed_state.num_processes): - # temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - # ext_separated_list_of_batches[x] = [] - # for batches in temp_list: - # ext_separated_list_of_batches[x].extend(batches) + if distributed_state.num_processes > 1: + for x in range(len(ext_separated_list_of_batches)): + temp_list = [] + for i in range(distributed_state.num_processes): + temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + ext_separated_list_of_batches[x] = [] + for batches in temp_list: + ext_separated_list_of_batches[x].extend(batches) distributed_state.wait_for_everyone() # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)