diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 316c9d0d..78039b2a 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2881,7 +2881,13 @@ def main(args): n, m = divmod(len(sublist), device) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)]) batch_separated_list.append(split_into_batches) - + if distributed_state.num_processes > 1: + templist = [] + for i in range(distributed_state.num_processes): + templist.append(batch_separated_list[i :: distributed_state.num_processes]) + batch_separated_list = [] + for sub_batch_list in templist: + batch_separated_list.extend(sub_batch_list) distributed_state.wait_for_everyone() batch_data = gather_object(batch_separated_list) del extinfo