From 31d0059dfb0d40bb2e4720749b6f578f2d8a0473 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Tue, 15 Apr 2025 01:42:47 +0800 Subject: [PATCH] Update train_util.py --- library/train_util.py | 53 ++++++++++++++++++++----------------------- 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index bfeb7082..e7456be0 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5424,7 +5424,7 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict -RE_CAPTION_PROMPT = re.compile(r"(?i)__caption\|?(.+?)__") +RE_CAPTION_PROMPT = re.compile(r"(?i)__caption((\|)(.+?)?)?__") def sample_images_common( pipe_class, accelerator: Accelerator, @@ -5538,32 +5538,31 @@ def sample_images_common( prompt_dict = line_to_prompt_dict(prompt_dict) prompts[i] = prompt_dict assert isinstance(prompt_dict, dict) - - if '__caption|' in prompt_dict.get("prompt"): + selected = "" + if '__caption' in prompt_dict.get("prompt"): match_caption = RE_CAPTION_PROMPT.search(prompt_dict.get("prompt")) - if not example_tuple: - caption_list = match_caption.group(1).split("|") - selected = random.choice(caption_list) - prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else '__caption__') - else: - prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), '__caption__') - - if '__caption__' in prompt_dict.get("prompt") and example_tuple: - logger.info(f"Original prompt: {prompt_dict.get('prompt')}") - - while latents_list[idx]["prompt"] == '': - idx = (idx + 1) % len(latents_list) - if idx == 0: - break - - prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{latents_list[idx]["prompt"]}') - logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}") - prompt_dict["height"] = latents_list[idx]["height"] - logger.info(f"Original Image Height: {prompt_dict['height']}") - prompt_dict["width"] = latents_list[idx]["width"] - logger.info(f"Original Image Width: {prompt_dict['width']}") - prompt_dict["original_lantent"] = latents_list[idx]["original_lantent"] - idx = (idx + 1) % len(latents_list) + if match_caption is not None: + if not example_tuple: + + if match_caption.group(3) is not None: + caption_list = match_caption.group(3).split("|") + selected = random.choice(caption_list) + prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else f'Astronaut riding a horse on the moon') + logger.info(f"Backup prompt: {prompt_dict.get('prompt')}") + else: + while latents_list[idx]["prompt"] == '': + idx = (idx + 1) % len(latents_list) + if idx == 0: + break + + prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), f'{latents_list[idx]["prompt"]}') + #logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}") + prompt_dict["height"] = latents_list[idx]["height"] + #logger.info(f"Original Image Height: {prompt_dict['height']}") + prompt_dict["width"] = latents_list[idx]["width"] + #logger.info(f"Original Image Width: {prompt_dict['width']}") + prompt_dict["original_lantent"] = latents_list[idx]["original_lantent"] + idx = (idx + 1) % len(latents_list) prompt_dict["enum"] = i prompt_dict.pop("subset", None) @@ -5598,8 +5597,6 @@ def sample_images_common( with distributed_state.split_between_processes(prompts) as prompt_dict_lists: for prompt_dict in prompt_dict_lists: if '__caption__' in prompt_dict.get("prompt"): - prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'Astronaut riding a horse on the moon') - logger.info("No training prompts loaded, replacing with placeholder 'Astronaut riding a horse on the moon' prompt.") sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet )