mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 00:32:25 +00:00
Update accel_sdxl_gen_img.py
This commit is contained in:
@@ -2420,9 +2420,6 @@ def main(args):
|
||||
else:
|
||||
fln = f"im_{ds_str}_{globalcount:02d}_{ts_str}_{highres_prefix}{i:03d}_{seed}.png"
|
||||
|
||||
#logger.info(f"Saving Image: {fln}:\nPrompt: {prompt}")
|
||||
if negative_prompt is not None:
|
||||
# logger.info(f"Negative Prompt: {negative_prompt}\n")
|
||||
image.save(os.path.join(args.outdir, fln), pnginfo=metadata)
|
||||
|
||||
if not args.no_preview and not highres_1st and args.interactive:
|
||||
@@ -2443,7 +2440,7 @@ def main(args):
|
||||
# 画像生成のプロンプトが一周するまでのループ
|
||||
prompt_index = 0
|
||||
global_step = 0
|
||||
batch_data = []
|
||||
prompt_data_list = []
|
||||
extinfo = []
|
||||
while args.interactive or (prompt_index < len(prompt_list) and distributed_state.is_main_process):
|
||||
if len(prompt_list) == 0:
|
||||
@@ -2867,26 +2864,24 @@ def main(args):
|
||||
num_sub_prompts,
|
||||
),
|
||||
)
|
||||
batch_data.append(b1)
|
||||
prompt_data_list.append(b1)
|
||||
extinfo.append(b1.ext)
|
||||
global_step += 1
|
||||
|
||||
prompt_index += 1
|
||||
batch_data = gather_object(batch_data)
|
||||
prompt_data_list = gather_object(prompt_data_list)
|
||||
extinfo = gather_object(extinfo)
|
||||
logger.info(f"batch_data line 2878: {len(batch_data)}")
|
||||
batch_separated_list = []
|
||||
logger.info(f"Device {distributed_state.device}, distributed_state.is_main_process 2878: {distributed_state.is_main_process}")
|
||||
if len(batch_data) > 0 and distributed_state.is_main_process:
|
||||
|
||||
ext_separated_list_of_batches = []
|
||||
if len(prompt_data_list) > 0 and distributed_state.is_main_process:
|
||||
unique_extinfo = list(set(extinfo))
|
||||
logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {len(unique_extinfo)}")
|
||||
logger.info(f"Device {distributed_state.device}, prompt_data_list line 2880: {len(prompt_data_list)}, len(unique_extinfo): {len(unique_extinfo)}")
|
||||
# splits list of prompts into sublists where BatchDataExt ext is identical
|
||||
for i in range(len(unique_extinfo)):
|
||||
logger.info(f"Device {distributed_state.device}, batch_data line 2880: {len(batch_data)}, len(unique_extinfo): {i}")
|
||||
templist = []
|
||||
res = [j for j, val in enumerate(batch_data) if val.ext == unique_extinfo[i]]
|
||||
res = [j for j, val in enumerate(prompt_data_list) if val.ext == unique_extinfo[i]]
|
||||
for index in res:
|
||||
templist.append(batch_data[index])
|
||||
templist.append(prompt_data_list[index])
|
||||
split_into_batches = get_batches(items=templist, batch_size=args.batch_size)
|
||||
if(len(split_into_batches) % distributed_state.num_processes != 0):
|
||||
#Distributes last round of batches across available processes if last round of batches less than available processes and there is more than one prompt in the last batch
|
||||
@@ -2899,28 +2894,34 @@ def main(args):
|
||||
split_into_batches = []
|
||||
n, m = divmod(len(sublist), distributed_state.num_processes)
|
||||
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])
|
||||
batch_separated_list.append(split_into_batches)
|
||||
logger.info(f"batch_separated_list line 2901: {len(batch_separated_list)}, {distributed_state.num_processes}")
|
||||
ext_separated_list_of_batches.append(split_into_batches)
|
||||
if distributed_state.num_processes > 1:
|
||||
logger.info(f"batch_separated_list: {len(batch_separated_list)}")
|
||||
|
||||
temp_list = []
|
||||
for ext_batch in batch_separated_list:
|
||||
|
||||
for x in range(len(ext_separated_list_of_batches)):
|
||||
temp_list = []
|
||||
for i in range(distributed_state.num_processes):
|
||||
temp_list.append(ext_batch[i :: distributed_state.num_processes])
|
||||
logger.info(f"templist: {len(temp_list)}")
|
||||
batch_separated_list = []
|
||||
for sub_batch_list in temp_list:
|
||||
batch_separated_list.append(sub_batch_list)
|
||||
logger.info(f"batch_separated_list: {len(batch_separated_list)}")
|
||||
logger.info(f"sub_batch_list: {len(sub_batch_list)}")
|
||||
temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes])
|
||||
ext_separated_list_of_batches[x] = []
|
||||
for batches in temp_list:
|
||||
ext_separated_list_of_batches[x].extend(batches)
|
||||
logger.info(f"templist: {len(temp_list)}, ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}")
|
||||
|
||||
logger.info(f"ext_separated_list_of_batches: {len(ext_separated_list_of_batches)}")
|
||||
count_prompts = 0
|
||||
for sub_batch_list in ext_separated_list_of_batches:
|
||||
logger.info(f" sub_batch_list: {len(sub_batch_list)}")
|
||||
for batches in sub_batch_list:
|
||||
logger.info(f" batches: {len(batches)}")
|
||||
count_prompts += len(batches)
|
||||
logger.info(f"count_prompts: {count_prompts}")
|
||||
|
||||
distributed_state.wait_for_everyone()
|
||||
batch_separated_list = gather_object(batch_separated_list)
|
||||
logger.info(f"batch_data line 2911: {len(batch_data)}")
|
||||
ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches)
|
||||
del extinfo
|
||||
|
||||
if len(batch_separated_list) > 0:
|
||||
for batch_list in batch_separated_list:
|
||||
if len(ext_separated_list_of_batches) > 0:
|
||||
for batch_list in ext_separated_list_of_batches:
|
||||
with distributed_state.split_between_processes(batch_list) as batches:
|
||||
for j in range(len(batches)):
|
||||
logger.info(f"\nLoading batch {j+1}/{len(batches)} of {len(batches[j])} prompts onto device {distributed_state.local_process_index}:")
|
||||
@@ -2932,17 +2933,17 @@ def main(args):
|
||||
distributed_state.wait_for_everyone()
|
||||
#for i in range(len(data_loader)):
|
||||
# logger.info(f"Loading Batch {i+1} of {len(data_loader)}")
|
||||
# batch_data_split.append(data_loader[i])
|
||||
# prompt_data_list_split.append(data_loader[i])
|
||||
# if (i+1) % distributed_state.num_processes != 0 and (i+1) != len(data_loader):
|
||||
# continue
|
||||
# with torch.no_grad():
|
||||
# with distributed_state.split_between_processes(batch_data_split) as batch_list:
|
||||
# with distributed_state.split_between_processes(prompt_data_list_split) as batch_list:
|
||||
# logger.info(f"Loading batch of {len(batch_list[0])} prompts onto device {distributed_state.local_process_index}:")
|
||||
# logger.info(f"batch_list:")
|
||||
# for i in range(len(batch_list[0])):
|
||||
# logger.info(f"Device {distributed_state.device}: Prompt {i+1}: {batch_list[0][i].base.prompt}")
|
||||
# prev_image = process_batch(batch_list[0], highres_fix)[0]
|
||||
# batch_data_split.clear()
|
||||
# prompt_data_list_split.clear()
|
||||
# distributed_state.wait_for_everyone()
|
||||
logger.info("done!")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user