mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Update accel_sdxl_gen_img.py
This commit is contained in:
@@ -2908,51 +2908,14 @@ def main(args):
|
|||||||
for i in range(distributed_state.num_processes):
|
for i in range(distributed_state.num_processes):
|
||||||
resorted_list.append(templist[i :: distributed_state.num_processes])
|
resorted_list.append(templist[i :: distributed_state.num_processes])
|
||||||
templist = []
|
templist = []
|
||||||
for list_of_prompts in resorted_list:
|
for list_of_prompts in resorted_list:
|
||||||
templist.extend(list_of_prompts)
|
templist.extend(get_batches(items=list_of_prompts, batch_size=args.batch_size).copy())
|
||||||
|
split_into_batches = templist
|
||||||
split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy()
|
|
||||||
|
|
||||||
sublist = []
|
|
||||||
if(len(split_into_batches) % distributed_state.num_processes != 0):
|
|
||||||
#Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch
|
|
||||||
popnum = len(split_into_batches) % distributed_state.num_processes
|
|
||||||
else:
|
else:
|
||||||
#force distribution check on last round of batches
|
split_into_batches = get_batches(items=templist, batch_size=args.batch_size).copy()
|
||||||
popnum = distributed_state.num_processes
|
|
||||||
|
|
||||||
for j in range(popnum):
|
|
||||||
if len(split_into_batches) > 1 :
|
|
||||||
sublist.extend(split_into_batches.pop(-1))
|
|
||||||
elif len(split_into_batches) == 1 :
|
|
||||||
sublist.extend(split_into_batches.pop(-1))
|
|
||||||
split_into_batches = []
|
|
||||||
sublist = sorted(sublist, key=lambda x: x.global_count)
|
|
||||||
resorted_list = []
|
|
||||||
for i in range(distributed_state.num_processes):
|
|
||||||
resorted_list.append(sublist[i :: distributed_state.num_processes])
|
|
||||||
sublist = []
|
|
||||||
for list_of_prompts in resorted_list:
|
|
||||||
sublist.extend(list_of_prompts)
|
|
||||||
|
|
||||||
n, m = divmod(len(sublist), distributed_state.num_processes)
|
|
||||||
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
|
|
||||||
|
|
||||||
ext_separated_list_of_batches.append(split_into_batches)
|
ext_separated_list_of_batches.append(split_into_batches)
|
||||||
'''
|
|
||||||
logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}")
|
|
||||||
if distributed_state.num_processes > 1:
|
|
||||||
for x in range(len(ext_separated_list_of_batches)):
|
|
||||||
temp_list = []
|
|
||||||
for i in range(distributed_state.num_processes):
|
|
||||||
temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes])
|
|
||||||
holder = []
|
|
||||||
for batches in temp_list:
|
|
||||||
holder.extend(batches)
|
|
||||||
ext_separated_list_of_batches[x] = holder[:]
|
|
||||||
'''
|
|
||||||
distributed_state.wait_for_everyone()
|
|
||||||
|
|
||||||
if distributed_state.is_main_process:
|
if distributed_state.is_main_process:
|
||||||
batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n"
|
batchlogstr = "Running through ext_separated_list_of_batches Before Gather:\n"
|
||||||
for x in range(len(ext_separated_list_of_batches)):
|
for x in range(len(ext_separated_list_of_batches)):
|
||||||
@@ -2962,7 +2925,7 @@ def main(args):
|
|||||||
for z in range(len(ext_separated_list_of_batches[x][y])):
|
for z in range(len(ext_separated_list_of_batches[x][y])):
|
||||||
batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n"
|
batchlogstr += f" Image {z} of {len(ext_separated_list_of_batches[x][y])} break: {ext_separated_list_of_batches[x][y][z].global_count}\n"
|
||||||
logger.info(batchlogstr)
|
logger.info(batchlogstr)
|
||||||
|
distributed_state.wait_for_everyone()
|
||||||
ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)
|
ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)
|
||||||
del extinfo
|
del extinfo
|
||||||
#logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}")
|
#logger.info(f"\nDevice {distributed_state.device}: {ext_separated_list_of_batches}")
|
||||||
|
|||||||
Reference in New Issue
Block a user