Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-01-28 05:33:34 +08:00
committed by GitHub
parent e96ea1b841
commit b1ca9d9485

View File

@@ -2441,6 +2441,7 @@ def main(args):
prompt_index = 0
global_step = 0
batch_data = []
extinfo = []
while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)):
if len(prompt_list) == 0:
# interactive
@@ -2766,10 +2767,11 @@ def main(args):
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
if len(seeds) > 0:
if len(seeds) > 1:
seed = seeds.pop(0)
if len(seeds) == 1:
seeds = None
elif len(seeds) == 1:
seed = seeds.pop(0)
seeds = None
else:
if predefined_seeds is not None:
if len(predefined_seeds) > 0:
@@ -2847,23 +2849,48 @@ def main(args):
),
)
batch_data.append(b1)
extinfo.append(b1.ext)
global_step += 1
prompt_index += 1
batch_separated_list = []
if distributed_state.is_main_process and len(batch_data) > 0:
unique_extinfo = list(set(extinfo))
# splits list of prompts into sublists where BatchDataExt ext is identical
for i in range(len(unique_extinfo)):
templist = []
res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]]
for index in res:
templist.append(batch_data[index])
split_into_batches = get_batches(items=templist, batch_size=args.batch_size)
if(len(split_into_batches) % distributed_state.num_processes != 0):
#Distributes last round of batches across available processes if last round of batches less than availble processes and there is more than one prompt in the last batch
sublist = []
for j in range(len(split_into_batches) % distributed_state.num_processes):
if len(split_into_batches) > 1 :
sublist.extend(split_into_batches.pop(-1))
elif len(split_into_batches) == 1 :
sublist.extend(split_into_batches.pop(-1))
listofbatches = []
n, m = divmod(len(sublist), device)
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)])
batch_separated_list.append(split_into_batches)
distributed_state.wait_for_everyone()
batch_data = gather_object(batch_data)
batch_data = gather_object(batch_separated_list)
del extinfo
if len(batch_data) > 0:
data_loader = get_batches(items=batch_data, batch_size=args.batch_size)
with distributed_state.split_between_processes(data_loader) as batch_list:
for j in range(len(batch_list)):
logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:")
logger.info(f"batch_list:")
for i in range(len(batch_list[j])):
logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\nNegative Prompt: {batch_list[j][i].base.negative_prompt}\nSeed: {batch_list[j][i].base.seed}")
prev_image = process_batch(batch_list[j], highres_fix)[0]
distributed_state.wait_for_everyone()
for batch_list in batch_data:
with distributed_state.split_between_processes(batch_list) as batches:
for j in range(len(batches)):
logger.info(f"Loading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:")
logger.info(f"batch_list:")
for i in range(len(batches[j])):
logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}")
prev_image = process_batch(batch_list[j], highres_fix)[0]
distributed_state.wait_for_everyone()
#for i in range(len(data_loader)):
# logger.info(f"Loading Batch {i+1} of {len(data_loader)}")
# batch_data_split.append(data_loader[i])