From 8967e2f0144ac8c158862162d5a5842fdff754c5 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Sun, 2 Feb 2025 00:50:03 +0800 Subject: [PATCH] Update accel_sdxl_gen_img.py --- accel_sdxl_gen_img.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/accel_sdxl_gen_img.py b/accel_sdxl_gen_img.py index 6bfa14aa..15e9ebd2 100644 --- a/accel_sdxl_gen_img.py +++ b/accel_sdxl_gen_img.py @@ -2469,6 +2469,8 @@ def main(args): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] + if pi == 0 or len(raw_prompts) > 1: + logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") for parg in prompt_args[1:]: try: @@ -2481,8 +2483,7 @@ def main(args): except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"{ex}") - if pi == 0 or len(raw_prompts) > 1: - logger.info(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + if pi == 0: # parse prompt: if prompt is not changed, skip parsing @@ -2904,23 +2905,12 @@ def main(args): if distributed_state.num_processes > 1: for x in range(len(ext_separated_list_of_batches)): temp_list = [] - logger.info(f"start: ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}") for i in range(distributed_state.num_processes): temp_list.append(ext_separated_list_of_batches[x][i :: distributed_state.num_processes]) ext_separated_list_of_batches[x] = [] for batches in temp_list: ext_separated_list_of_batches[x].extend(batches) - logger.info(f"templist: {len(temp_list)}, ext_separated_list_of_batches[x]: {len(ext_separated_list_of_batches[x])}") - - logger.info(f"ext_separated_list_of_batches: {len(ext_separated_list_of_batches)}") - count_prompts = 0 - for sub_batch_list in ext_separated_list_of_batches: - logger.info(f" sub_batch_list: {len(sub_batch_list)}") - for batches in sub_batch_list: - logger.info(f" batches: {len(batches)}") - count_prompts += len(batches) - logger.info(f"count_prompts: {count_prompts}") - + distributed_state.wait_for_everyone() ext_separated_list_of_batches = gather_object(ext_separated_list_of_batches) del extinfo