Update train_util.py

This commit is contained in:
DKnight54
2025-04-14 01:37:59 +08:00
committed by GitHub
parent eb2d9abff6
commit 4bb58fbbfb

View File

@@ -5424,7 +5424,7 @@ def line_to_prompt_dict(line: str) -> dict:
return prompt_dict
RE_CAPTION_PROMPT = re.compile(r"(?i)__caption\|?(.+?)__")
def sample_images_common(
pipe_class,
accelerator: Accelerator,
@@ -5539,6 +5539,15 @@ def sample_images_common(
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
if '__caption|' in prompt_dict.get("prompt"):
match_caption = RE_CAPTION_PROMPT.search(prompt_dict.get("prompt"))
if not example_tuple:
caption_list = match_caption.group(1).split("|")
selected = random.choice(caption_list)
prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else '__caption__')
else:
prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), '__caption__')
if '__caption__' in prompt_dict.get("prompt") and example_tuple:
logger.info(f"Original prompt: {prompt_dict.get('prompt')}")