diff --git a/library/train_util.py b/library/train_util.py
index d4a0ec45..8c114297 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5518,7 +5518,7 @@ def sample_images_common(
                 idx = (idx + 1) % len(example_tuple[1])
                 if idx == 0:
                     break
-            prompts[i]["prompt"] = prompt_dict.get("prompt").replace('__caption__', 'example_tuple[1][idx]')
+            prompts[i]["prompt"] = prompt_dict.get("prompt").replace('__caption__', example_tuple[1][idx])
             prompts[i]["height"] = example_tuple[0].shape[2] * 8
             prompts[i]["width"] = example_tuple[0].shape[3] * 8
             prompts[i]["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
@@ -5540,6 +5540,9 @@ def sample_images_common(
         # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
         with torch.no_grad():
             for prompt_dict in prompts:
+                if prompt_dict["prompt"] == '__caption__':
+                    logger.info("No training prompts loaded, skipping '__caption__' prompt.")
+                    continue
                 sample_image_inference(
                     accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
                 )
@@ -5553,6 +5556,9 @@ def sample_images_common(
         with torch.no_grad():
             with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
                 for prompt_dict in prompt_dict_lists[0]:
+                    if prompt_dict["prompt"] == '__caption__':
+                        logger.info("No training prompts loaded, skipping '__caption__' prompt.")
+                        continue
                     sample_image_inference(
                         accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
                     )
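
Note on the two changes above, as a minimal standalone sketch (plain Python; the `captions` list and `prompt` value are hypothetical stand-ins for `example_tuple[1]` and a prompt-file entry, not code from this repository):

    # The removed line quoted the expression, so str.replace() substituted the
    # literal text "example_tuple[1][idx]" into the prompt instead of the caption.
    captions = ["a photo of a cat", "a photo of a dog"]  # stand-in for example_tuple[1]
    idx = 1
    prompt = "__caption__, best quality"

    buggy = prompt.replace('__caption__', 'example_tuple[1][idx]')
    assert buggy == "example_tuple[1][idx], best quality"

    fixed = prompt.replace('__caption__', captions[idx])
    assert fixed == "a photo of a dog, best quality"

    # The added guard assumes that when no training captions were loaded, the
    # placeholder is never replaced, so a prompt that is still exactly
    # '__caption__' is skipped instead of being sampled verbatim.
    for prompt_dict in [{"prompt": "__caption__"}, {"prompt": fixed}]:
        if prompt_dict["prompt"] == '__caption__':
            continue  # would skip sample_image_inference for this prompt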