diff --git a/library/train_util.py b/library/train_util.py index b32b80d2..1d4bf90d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5432,7 +5432,7 @@ def sample_images_common( tokenizer, text_encoder, unet, - latents_list=None, + example_tuple=None, prompt_replacement=None, controlnet=None, ): @@ -5511,12 +5511,23 @@ def sample_images_common( # preprocess prompts - + if example_tuple: + latents_list = [] + for idx in range(len(example_tuple[1])): + latent_dict = {} + latent_dict["prompt"] = example_tuple[1][idx] + latent_dict["height"] = example_tuple[0].shape[2] * 8 + latent_dict["width"] = example_tuple[0].shape[3] * 8 + latent_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) + latents_list.append(latent_dict) + distributed_state.wait_for_everyone() + latents_list = gather_object(latents_list) save_dir = args.output_dir + "/sample" if distributed_state.is_main_process: #Create output folder and preprocess prompts on main process only. os.makedirs(save_dir, exist_ok=True) idx = 0 + for i in range(len(prompts)): prompt_dict = prompts[i] if isinstance(prompt_dict, str): @@ -5524,7 +5535,7 @@ def sample_images_common( prompts[i] = prompt_dict assert isinstance(prompt_dict, dict) - if '__caption__' in prompt_dict.get("prompt") and latents_list: + if '__caption__' in prompt_dict.get("prompt") and example_tuple: logger.info(f"Original prompt: {prompt_dict.get('prompt')}") while latents_list[idx]["prompt"] == '':