diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index d8c8f044..0f2b3bd0 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2902,14 +2902,14 @@ def main(args): res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]] for index in res: templist.append(prompt_data_list[index]) - ''' + if distributed_state.num_processes > 1: resorted_list = [] for i in range(distributed_state.num_processes): resorted_list.append(templist[i :: distributed_state.num_processes]) for list_of_prompts in resorted_list: templist.extend(list_of_prompts) - ''' + split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() sublist = [] @@ -2926,6 +2926,7 @@ def main(args): elif len(split_into_batches) == 1 : sublist.extend(split_into_batches.pop(-1)) split_into_batches = [] + # sublist = sorted(sublist, key=lambda x: x.global_count) n, m = divmod(len(sublist), distributed_state.num_processes) split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])