Update accel_sdxl_gen_img.py

This commit is contained in:
DKnight54
2025-01-26 03:49:59 +08:00
committed by GitHub
parent a832925553
commit 2328128b1c

View File

@@ -2454,10 +2454,6 @@ def main(args):
# sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
raw_prompts = handle_dynamic_prompt_variants(raw_prompt, args.images_per_prompt)
for pi in range(distributed_state.num_processes):
if pi == distributed_state.local_process_index:
logger.info(f"Total raw prompts: {len(raw_prompts)} for {distributed_state.local_process_index}")
# repeat prompt
for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)):
@@ -2846,14 +2842,9 @@ def main(args):
prompt_index += 1
distributed_state.wait_for_everyone()
batch_data = gather_object(batch_data)
for pi in range(distributed_state.num_processes):
if pi == distributed_state.local_process_index:
logger.info(f"Total prompts: {len(batch_data)} for {distributed_state.local_process_index}")
if len(batch_data) > 0:
data_loader = get_batches(items=batch_data, batch_size=args.batch_size)
logger.info(f"Total batches: {len(data_loader)}")
batch_data_split = [] #[batch_data[i:i+3] for i in range(0, len(my_list), 3)]
batch_index = []
with distributed_state.split_between_processes(data_loader) as batch_list:
for j in range(len(batch_list)):
logger.info(f"Loading batch {j}/{len(batch_list)} of {len(batch_list[j])} prompts onto device {distributed_state.local_process_index}:")