From 4bb58fbbfb0a1e090c6a106b0a096f37b7ecf141 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Mon, 14 Apr 2025 01:37:59 +0800 Subject: [PATCH] Update train_util.py --- library/train_util.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index b5fb103e..bfeb7082 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5424,7 +5424,7 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict - +RE_CAPTION_PROMPT = re.compile(r"(?i)__caption\|?(.+?)__") def sample_images_common( pipe_class, accelerator: Accelerator, @@ -5539,6 +5539,15 @@ def sample_images_common( prompts[i] = prompt_dict assert isinstance(prompt_dict, dict) + if '__caption|' in prompt_dict.get("prompt"): + match_caption = RE_CAPTION_PROMPT.search(prompt_dict.get("prompt")) + if not example_tuple: + caption_list = match_caption.group(1).split("|") + selected = random.choice(caption_list) + prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else '__caption__') + else: + prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), '__caption__') + if '__caption__' in prompt_dict.get("prompt") and example_tuple: logger.info(f"Original prompt: {prompt_dict.get('prompt')}")