mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
Update accel_sdxl_gen_img.py
This commit is contained in:
@@ -2441,6 +2441,7 @@ def main(args):
|
||||
prompt_index = 0
|
||||
global_step = 0
|
||||
batch_data = []
|
||||
extinfo = []
|
||||
while args.interactive or (prompt_index < len(prompt_list) and (not distributed_state.is_main_process or distributed_state.num_processes == 1)):
|
||||
if len(prompt_list) == 0:
|
||||
# interactive
|
||||
@@ -2766,10 +2767,11 @@ def main(args):
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
if len(seeds) > 0:
|
||||
if len(seeds) > 1:
|
||||
seed = seeds.pop(0)
|
||||
if len(seeds) == 1:
|
||||
seeds = None
|
||||
elif len(seeds) == 1:
|
||||
seed = seeds.pop(0)
|
||||
seeds = None
|
||||
else:
|
||||
if predefined_seeds is not None:
|
||||
if len(predefined_seeds) > 0:
|
||||
@@ -2847,23 +2849,48 @@ def main(args):
|
||||
),
|
||||
)
|
||||
batch_data.append(b1)
|
||||
extinfo.append(b1.ext)
|
||||
global_step += 1
|
||||
|
||||
prompt_index += 1
|
||||
batch_separated_list = []
|
||||
if distributed_state.is_main_process and len(batch_data) > 0:
|
||||
unique_extinfo = list(set(extinfo))
|
||||
# splits list of prompts into sublists where BatchDataExt ext is identical
|
||||
for i in range(len(unique_extinfo)):
|
||||
templist = []
|
||||
res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]]
|
||||
for index in res:
|
||||
templist.append(batch_data[index])
|
||||
split_into_batches = get_batches(items=templist, batch_size=args.batch_size)
|
||||
if(len(split_into_batches) % distributed_state.num_processes != 0):
|
||||
#Distributes last round of batches across available processes if last round of batches less than availble processes and there is more than one prompt in the last batch
|
||||
sublist = []
|
||||
for j in range(len(split_into_batches) % distributed_state.num_processes):
|
||||
if len(split_into_batches) > 1 :
|
||||
sublist.extend(split_into_batches.pop(-1))
|
||||
elif len(split_into_batches) == 1 :
|
||||
sublist.extend(split_into_batches.pop(-1))
|
||||
listofbatches = []
|
||||
n, m = divmod(len(sublist), device)
|
||||
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(device)])
|
||||
batch_separated_list.append(split_into_batches)
|
||||
|
||||
distributed_state.wait_for_everyone()
|
||||
batch_data = gather_object(batch_data)
|
||||
|
||||
batch_data = gather_object(batch_separated_list)
|
||||
del extinfo
|
||||
|
||||
if len(batch_data) > 0:
|
||||
data_loader = get_batches(items=batch_data, batch_size=args.batch_size)
|
||||
with distributed_state.split_between_processes(data_loader) as batch_list:
|
||||
for j in range(len(batch_list)):
|
||||
logger.info(f"Loading batch {j+1}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:")
|
||||
logger.info(f"batch_list:")
|
||||
for i in range(len(batch_list[j])):
|
||||
logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[j][i].base.prompt}\nNegative Prompt: {batch_list[j][i].base.negative_prompt}\nSeed: {batch_list[j][i].base.seed}")
|
||||
prev_image = process_batch(batch_list[j], highres_fix)[0]
|
||||
|
||||
distributed_state.wait_for_everyone()
|
||||
for batch_list in batch_data:
|
||||
with distributed_state.split_between_processes(batch_list) as batches:
|
||||
for j in range(len(batches)):
|
||||
logger.info(f"Loading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:")
|
||||
logger.info(f"batch_list:")
|
||||
for i in range(len(batches[j])):
|
||||
logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batches[j][i].base.prompt}\nNegative Prompt: {batches[j][i].base.negative_prompt}\nSeed: {batches[j][i].base.seed}")
|
||||
prev_image = process_batch(batch_list[j], highres_fix)[0]
|
||||
|
||||
distributed_state.wait_for_everyone()
|
||||
#for i in range(len(data_loader)):
|
||||
# logger.info(f"Loading Batch {i+1} of {len(data_loader)}")
|
||||
# batch_data_split.append(data_loader[i])
|
||||
|
||||
Reference in New Issue
Block a user