diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 62a245ed..d6af05ad 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2908,51 +2908,14 @@ def main(args): for i in range(distributed_state.num_processes): resorted_list.append(templist[i :: distributed_state.num_processes]) templist = [] - for list_of_prompts in resorted_list: - templist.extend(list_of_prompts) - - split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() - - sublist = [] - if(len(split_into_batches) % distributed_state.num_processes != 0): - #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch - popnum = len(split_into_batches) % distributed_state.num_processes + for list_of_prompts in resorted_list: + templist.extend(get_batches(items=list_of_prompts, batch_size=args.batch_size).copy()) + split_into_batches = templist else: - #force distribution check on last round of batches - popnum = distributed_state.num_processes - - for j in range(popnum): - if len(split_into_batches) > 1 : - sublist.extend(split_into_batches.pop(-1)) - elif len(split_into_batches) == 1 : - sublist.extend(split_into_batches.pop(-1)) - split_into_batches = [] - sublist = sorted(sublist, key=lambda x: x.global_count) - resorted_list = [] - for i in range(distributed_state.num_processes): - resorted_list.append(sublist[i :: distributed_state.num_processes]) - sublist = [] - for list_of_prompts in resorted_list: - sublist.extend(list_of_prompts) + split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy() - n, m = divmod(len(sublist), distributed_state.num_processes) - split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) - ext_separated_list_of_batches.append(split_into_batches) - ''' - logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}") - if distributed_state.num_processes > 1: - for x in range(len(ext_separated_list_of_batches)): - temp_list = [] - for i in range(distributed_state.num_processes): - temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) - holder = [] - for batches in temp_list: - holder.extend(batches) - ext_separated_list_of_batches[x] = holder[:] - ''' - distributed_state.wait_for_everyone() - + if distributed_state.is_main_process: batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n" for x in range(len(ext_separated_list_of_batches)): @@ -2962,7 +2925,7 @@ def main(args): for z in range(len(ext_separated_list_of_batches[x][y])): batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n" logger.info(batchlogstr) - + distributed_state.wait_for_everyone() ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo #logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}")