From 3bb1d9d38984d363ef8c00bafa9570beea3121d0 Mon Sep 17 00:00:00 2001 From: DKnight54 <126916963+DKnight54@users.noreply.github.com> Date: Fri, 31 Jan 2025 20:02:57 +0800 Subject: [PATCH] Update train_util.py --- library/train_util.py | 135 +++++++++++++++++++++++------------------- 1 file changed, 73 insertions(+), 62 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c8021fec..8b663881 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5472,18 +5472,23 @@ def sample_images_common( text_encoder = accelerator.unwrap_model(text_encoder) # read prompts - if args.sample_prompts.endswith(".txt"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - lines = f.readlines() - prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] - elif args.sample_prompts.endswith(".toml"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - data = toml.load(f) - prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] - elif args.sample_prompts.endswith(".json"): - with open(args.sample_prompts, "r", encoding="utf-8") as f: - prompts = json.load(f) - + + if distributed_state.is_main_process: + # Load prompts into prompts list on main process only + if args.sample_prompts.endswith(".txt"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + lines = f.readlines() + prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"] + elif args.sample_prompts.endswith(".toml"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + data = toml.load(f) + prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]] + elif args.sample_prompts.endswith(".json"): + with open(args.sample_prompts, "r", encoding="utf-8") as f: + prompts = json.load(f) + else: + prompts = [] # Init empty prompts list for sub processes. + # schedulers: dict = {} cannot find where this is used default_scheduler = get_my_scheduler( sample_sampler=args.sample_sampler, @@ -5502,38 +5507,41 @@ def sample_images_common( clip_skip=args.clip_skip, ) pipeline.to(distributed_state.device) - save_dir = args.output_dir + "/sample" - os.makedirs(save_dir, exist_ok=True) + # preprocess prompts - idx = 0 - for i in range(len(prompts)): - prompt_dict = prompts[i] - if isinstance(prompt_dict, str): - prompt_dict = line_to_prompt_dict(prompt_dict) - prompts[i] = prompt_dict - assert isinstance(prompt_dict, dict) - - if '__caption__' in prompt_dict.get("prompt") and example_tuple: - logger.info(f"Original prompt: {prompt_dict.get('prompt')}") - - while example_tuple[1][idx] == '': + if distributed_state.is_main_process: + #Create output folder and preprocess prompts on main process only. + save_dir = args.output_dir + "/sample" + os.makedirs(save_dir, exist_ok=True) + idx = 0 + for i in range(len(prompts)): + prompt_dict = prompts[i] + if isinstance(prompt_dict, str): + prompt_dict = line_to_prompt_dict(prompt_dict) + prompts[i] = prompt_dict + assert isinstance(prompt_dict, dict) + + if '__caption__' in prompt_dict.get("prompt") and example_tuple: + logger.info(f"Original prompt: {prompt_dict.get('prompt')}") + + while example_tuple[1][idx] == '': + idx = (idx + 1) % len(example_tuple[1]) + if idx == 0: + break + prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}') + logger.info(f"Replacement prompt: {prompt_dict["prompt"]}") + prompt_dict["height"] = example_tuple[0].shape[2] * 8 + logger.info(f"Original Image Height: {prompt_dict["height"]}") + prompt_dict["width"] = example_tuple[0].shape[3] * 8 + logger.info(f"Original Image Width: {prompt_dict["width"]}") + prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) idx = (idx + 1) % len(example_tuple[1]) - if idx == 0: - break - prompt_dict["prompt"] = f"{example_tuple[1][idx]}" - logger.info(f"Replacement prompt: {example_tuple[1][idx]}") - prompt_dict["height"] = example_tuple[0].shape[2] * 8 - logger.info(f"Original Image Height: {prompt_dict["height"]}") - prompt_dict["width"] = example_tuple[0].shape[3] * 8 - logger.info(f"Original Image Width: {prompt_dict["width"]}") - prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0) - idx = (idx + 1) % len(example_tuple[1]) - - prompt_dict["enum"] = i - prompt_dict.pop("subset", None) - prompts[i] = prompt_dict - logger.info(f"Current prompt: {prompts[i].get('prompt')}") + + prompt_dict["enum"] = i + prompt_dict.pop("subset", None) + prompts[i] = prompt_dict + logger.info(f"Current prompt: {prompts[i].get('prompt')}") # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict. @@ -5546,32 +5554,34 @@ def sample_images_common( except Exception: pass - if distributed_state.num_processes <= 1: - # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts. - with torch.no_grad(): - for prompt_dict in prompts: - if prompt_dict["prompt"] == '__caption__': + if distributed_state.num_processes > 1 and distributed_state.is_main_process: + per_process_prompts = [] # list of lists + for i in range(distributed_state.num_processes): + per_process_prompts.append(prompts[i :: distributed_state.num_processes]) + prompts = [] + for prompt in per_process_prompts: + prompts.extend(prompt) + distributed_state.wait_for_everyone() + per_process_prompts = gather_object(prompts) + prompts = [] + for prompt in per_process_prompts: + prompts.extend(prompt) + + + + # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) + # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. + + + with torch.no_grad(): + with distributed_state.split_between_processes(prompts) as prompt_dict_lists: + for prompt_dict in prompt_dict_lists: + if '__caption__' in prompt_dict.get("prompt"): logger.info("No training prompts loaded, skipping '__caption__' prompt.") continue sample_image_inference( accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet ) - else: - # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available) - # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical. - per_process_prompts = [] # list of lists - for i in range(distributed_state.num_processes): - per_process_prompts.append(prompts[i :: distributed_state.num_processes]) - - with torch.no_grad(): - with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists: - for prompt_dict in prompt_dict_lists[0]: - if prompt_dict["prompt"] == '__caption__': - logger.info("No training prompts loaded, skipping '__caption__' prompt.") - continue - sample_image_inference( - accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet - ) # clear pipeline and cache to reduce vram usage del pipeline @@ -5699,6 +5709,7 @@ def sample_image_inference( image = pipeline.latents_to_image(latents)[0] if "original_lantent" in prompt_dict: + #Prevent out of VRAM error if torch.cuda.is_available(): with torch.cuda.device(torch.cuda.current_device()): torch.cuda.empty_cache()