diff --git a/library/train_util.py b/library/train_util.py index b5fb103e..bfeb7082 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5424,7 +5424,7 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict - +RE_CAPTION_PROMPT = re.compile(r"(?i)__caption\|?(.+?)__") def sample_images_common( pipe_class, accelerator: Accelerator, @@ -5539,6 +5539,15 @@ def sample_images_common( prompts[i] = prompt_dict assert isinstance(prompt_dict, dict) + if '__caption|' in prompt_dict.get("prompt"): + match_caption = RE_CAPTION_PROMPT.search(prompt_dict.get("prompt")) + if not example_tuple: + caption_list = match_caption.group(1).split("|") + selected = random.choice(caption_list) + prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else '__caption__') + else: + prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), '__caption__') + if '__caption__' in prompt_dict.get("prompt") and example_tuple: logger.info(f"Original prompt: {prompt_dict.get('prompt')}")