mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
Update accel_sdxl_gen_img.py
This commit is contained in:
@@ -2876,39 +2876,40 @@ def main(args):
|
||||
logger.info(f"batch_data line 2878: {len(batch_data)}")
|
||||
batch_separated_list = []
|
||||
logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}")
|
||||
if distributed_state.is_main_process and len(batch_data) > 0:
|
||||
unique_extinfo = list(set(extinfo))
|
||||
logger.info(f"batch_data line 2880: {len(batch_data)}")
|
||||
# splits list of prompts into sublists where BatchDataExt ext is identical
|
||||
for i in range(len(unique_extinfo)):
|
||||
templist = []
|
||||
res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]]
|
||||
for index in res:
|
||||
templist.append(batch_data[index])
|
||||
split_into_batches = get_batches(items=templist, batch_size=args.batch_size)
|
||||
if(len(split_into_batches) % distributed_state.num_processes != 0):
|
||||
#Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch
|
||||
sublist = []
|
||||
for j in range(len(split_into_batches) % distributed_state.num_processes):
|
||||
if len(split_into_batches) > 1 :
|
||||
sublist.extend(split_into_batches.pop(-1))
|
||||
elif len(split_into_batches) == 1 :
|
||||
sublist.extend(split_into_batches.pop(-1))
|
||||
listofbatches = []
|
||||
n, m = divmod(len(sublist), device)
|
||||
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)])
|
||||
batch_separated_list.append(split_into_batches)
|
||||
logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}")
|
||||
if distributed_state.num_processes > 1:
|
||||
logger.info(f"batch_separated_list: {len(batch_separated_list)}")
|
||||
if len(batch_data) > 0:
|
||||
if distributed_state.is_main_process:
|
||||
unique_extinfo = list(set(extinfo))
|
||||
logger.info(f"batch_data line 2880: {len(batch_data)}")
|
||||
# splits list of prompts into sublists where BatchDataExt ext is identical
|
||||
for i in range(len(unique_extinfo)):
|
||||
templist = []
|
||||
for i in range(distributed_state.num_processes):
|
||||
templist.append(batch_separated_list[i :: distributed_state.num_processes])
|
||||
logger.info(f"templist: {len(templist)}")
|
||||
batch_separated_list = []
|
||||
for sub_batch_list in templist:
|
||||
batch_separated_list.extend(sub_batch_list)
|
||||
res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]]
|
||||
for index in res:
|
||||
templist.append(batch_data[index])
|
||||
split_into_batches = get_batches(items=templist, batch_size=args.batch_size)
|
||||
if(len(split_into_batches) % distributed_state.num_processes != 0):
|
||||
#Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch
|
||||
sublist = []
|
||||
for j in range(len(split_into_batches) % distributed_state.num_processes):
|
||||
if len(split_into_batches) > 1 :
|
||||
sublist.extend(split_into_batches.pop(-1))
|
||||
elif len(split_into_batches) == 1 :
|
||||
sublist.extend(split_into_batches.pop(-1))
|
||||
listofbatches = []
|
||||
n, m = divmod(len(sublist), device)
|
||||
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)])
|
||||
batch_separated_list.append(split_into_batches)
|
||||
logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}")
|
||||
if distributed_state.num_processes > 1:
|
||||
logger.info(f"batch_separated_list: {len(batch_separated_list)}")
|
||||
templist = []
|
||||
for i in range(distributed_state.num_processes):
|
||||
templist.append(batch_separated_list[i :: distributed_state.num_processes])
|
||||
logger.info(f"templist: {len(templist)}")
|
||||
batch_separated_list = []
|
||||
for sub_batch_list in templist:
|
||||
batch_separated_list.extend(sub_batch_list)
|
||||
logger.info(f"batch_separated_list: {len(batch_separated_list)}")
|
||||
distributed_state.wait_for_everyone()
|
||||
batch_data = gather_object(batch_separated_list)
|
||||
logger.info(f"batch_data line 2911: {len(batch_data)}")
|
||||
|
||||
Reference in New Issue
Block a user