diff --git a/library/train_util.py b/library/train_util.py
index bfeb7082..e7456be0 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5424,7 +5424,7 @@ def line_to_prompt_dict(line: str) -> dict:
     return prompt_dict
 
 
-RE_CAPTION_PROMPT = re.compile(r"(?i)__caption\|?(.+?)__")
+RE_CAPTION_PROMPT = re.compile(r"(?i)__caption((\|)(.+?)?)?__")
 def sample_images_common(
     pipe_class,
     accelerator: Accelerator,
@@ -5538,32 +5538,31 @@
             prompt_dict = line_to_prompt_dict(prompt_dict)
             prompts[i] = prompt_dict
         assert isinstance(prompt_dict, dict)
-
-        if '__caption|' in prompt_dict.get("prompt"):
+        selected = ""
+        if '__caption' in prompt_dict.get("prompt"):
             match_caption = RE_CAPTION_PROMPT.search(prompt_dict.get("prompt"))
-            if not example_tuple:
-                caption_list = match_caption.group(1).split("|")
-                selected = random.choice(caption_list)
-                prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else '__caption__')
-            else:
-                prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), '__caption__')
-
-        if '__caption__' in prompt_dict.get("prompt") and example_tuple:
-            logger.info(f"Original prompt: {prompt_dict.get('prompt')}")
-
-            while latents_list[idx]["prompt"] == '':
-                idx = (idx + 1) % len(latents_list)
-                if idx == 0:
-                    break
-
-            prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{latents_list[idx]["prompt"]}')
-            logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}")
-            prompt_dict["height"] = latents_list[idx]["height"]
-            logger.info(f"Original Image Height: {prompt_dict['height']}")
-            prompt_dict["width"] = latents_list[idx]["width"]
-            logger.info(f"Original Image Width: {prompt_dict['width']}")
-            prompt_dict["original_lantent"] = latents_list[idx]["original_lantent"]
-            idx = (idx + 1) % len(latents_list)
+            if match_caption is not None:
+                if not example_tuple:
+
+                    if match_caption.group(3) is not None:
+                        caption_list = match_caption.group(3).split("|")
+                        selected = random.choice(caption_list)
+                    prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), selected if selected else f'Astronaut riding a horse on the moon')
+                    logger.info(f"Backup prompt: {prompt_dict.get('prompt')}")
+                else:
+                    while latents_list[idx]["prompt"] == '':
+                        idx = (idx + 1) % len(latents_list)
+                        if idx == 0:
+                            break
+
+                    prompt_dict["prompt"] = prompt_dict.get("prompt").replace(match_caption.group(0), f'{latents_list[idx]["prompt"]}')
+                    #logger.info(f"Replacement prompt: {prompt_dict.get('prompt')}")
+                    prompt_dict["height"] = latents_list[idx]["height"]
+                    #logger.info(f"Original Image Height: {prompt_dict['height']}")
+                    prompt_dict["width"] = latents_list[idx]["width"]
+                    #logger.info(f"Original Image Width: {prompt_dict['width']}")
+                    prompt_dict["original_lantent"] = latents_list[idx]["original_lantent"]
+                    idx = (idx + 1) % len(latents_list)
 
         prompt_dict["enum"] = i
         prompt_dict.pop("subset", None)
@@ -5598,8 +5597,5 @@
     with distributed_state.split_between_processes(prompts) as prompt_dict_lists:
         for prompt_dict in prompt_dict_lists:
-            if '__caption__' in prompt_dict.get("prompt"):
-                prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'Astronaut riding a horse on the moon')
-                logger.info("No training prompts loaded, replacing with placeholder 'Astronaut riding a horse on the moon' prompt.")
             sample_image_inference(
                 accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
             )