diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index e2fe2eda..6d28b945 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2916,14 +2916,16 @@ def main(args): n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) - #if distributed_state.num_processes > 1: - # for x in range(len(ext_separated_list_of_batches)): - # temp_list = [] - # for i in range(distributed_state.num_processes): - # temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - # ext_separated_list_of_batches[x] = [] - # for batches in temp_list: - # ext_separated_list_of_batches[x].extend(batches) + logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") + if distributed_state.num_processes > 1: + for x in range(len(ext_separated_list_of_batches)): + temp_list = [] + for i in range(distributed_state.num_processes): + temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) + holder = [] + for batches in temp_list: + holder.extend(batches) + ext_separated_list_of_batches[x] = holder[:] distributed_state.wait_for_everyone() # ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)