From cc6266db19d395d9b8344bb2a5254429a63ea93a Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sat, 1 Feb 2025 23:36:54 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 5ed67d2f..6bfa14aa 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2883,21 +2883,25 @@ def main(args): for index in res: templist.append(prompt_data_list[index]) split_into_batches = get_batches(items=templist, batch_size=args.batch_size) + sublist = [] if(len(split_into_batches) % distributed_state.num_processes != 0): #Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch - sublist = [] - for j in range(len(split_into_batches) % distributed_state.num_processes): - if len(split_into_batches) > 1 : - sublist.extend(split_into_batches.pop(-1)) - elif len(split_into_batches) == 1 : - sublist.extend(split_into_batches.pop(-1)) - split_into_batches = [] - n, m = divmod(len(sublist), distributed_state.num_processes) - split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) + popnum = (len(split_into_batches) % distributed_state.num_processes + else: + #force distribution check on last round of batches + popnum = distributed_state.num_processes + + for j in range(popnum): + if len(split_into_batches) > 1 : + sublist.extend(split_into_batches.pop(-1)) + elif len(split_into_batches) == 1 : + sublist.extend(split_into_batches.pop(-1)) + split_into_batches = [] + + n, m = divmod(len(sublist), distributed_state.num_processes) + split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)]) ext_separated_list_of_batches.append(split_into_batches) if distributed_state.num_processes > 1: - - for x in range(len(ext_separated_list_of_batches)): temp_list = [] logger.info(f"start: ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}")