Update train_util.py

This commit is contained in:
DKnight54
2025-01-31 20:02:57 +08:00
committed by GitHub
parent a979ea5a50
commit 3bb1d9d389

View File

@@ -5472,18 +5472,23 @@ def sample_images_common(
text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
if distributed_state.is_main_process:
# Load prompts into prompts list on main process only
if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines()
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
elif args.sample_prompts.endswith(".toml"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
data = toml.load(f)
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
elif args.sample_prompts.endswith(".json"):
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
else:
prompts = [] # Init empty prompts list for sub processes.
# schedulers: dict = {} cannot find where this is used
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
@@ -5502,38 +5507,41 @@ def sample_images_common(
clip_skip=args.clip_skip,
)
pipeline.to(distributed_state.device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
# preprocess prompts
idx = 0
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
if '__caption__' in prompt_dict.get("prompt") and example_tuple:
logger.info(f"Original prompt: {prompt_dict.get('prompt')}")
while example_tuple[1][idx] == '':
if distributed_state.is_main_process:
#Create output folder and preprocess prompts on main process only.
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
idx = 0
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
if '__caption__' in prompt_dict.get("prompt") and example_tuple:
logger.info(f"Original prompt: {prompt_dict.get('prompt')}")
while example_tuple[1][idx] == '':
idx = (idx + 1) % len(example_tuple[1])
if idx == 0:
break
prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}')
logger.info(f"Replacement prompt: {prompt_dict["prompt"]}")
prompt_dict["height"] = example_tuple[0].shape[2] * 8
logger.info(f"Original Image Height: {prompt_dict["height"]}")
prompt_dict["width"] = example_tuple[0].shape[3] * 8
logger.info(f"Original Image Width: {prompt_dict["width"]}")
prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
idx = (idx + 1) % len(example_tuple[1])
if idx == 0:
break
prompt_dict["prompt"] = f"{example_tuple[1][idx]}"
logger.info(f"Replacement prompt: {example_tuple[1][idx]}")
prompt_dict["height"] = example_tuple[0].shape[2] * 8
logger.info(f"Original Image Height: {prompt_dict["height"]}")
prompt_dict["width"] = example_tuple[0].shape[3] * 8
logger.info(f"Original Image Width: {prompt_dict["width"]}")
prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
idx = (idx + 1) % len(example_tuple[1])
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
prompts[i] = prompt_dict
logger.info(f"Current prompt: {prompts[i].get('prompt')}")
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
prompts[i] = prompt_dict
logger.info(f"Current prompt: {prompts[i].get('prompt')}")
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
@@ -5546,32 +5554,34 @@ def sample_images_common(
except Exception:
pass
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad():
for prompt_dict in prompts:
if prompt_dict["prompt"] == '__caption__':
if distributed_state.num_processes > 1 and distributed_state.is_main_process:
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
prompts = []
for prompt in per_process_prompts:
prompts.extend(prompt)
distributed_state.wait_for_everyone()
per_process_prompts = gather_object(prompts)
prompts = []
for prompt in per_process_prompts:
prompts.extend(prompt)
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
with torch.no_grad():
with distributed_state.split_between_processes(prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists:
if '__caption__' in prompt_dict.get("prompt"):
logger.info("No training prompts loaded, skipping '__caption__' prompt.")
continue
sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
for i in range(distributed_state.num_processes):
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
if prompt_dict["prompt"] == '__caption__':
logger.info("No training prompts loaded, skipping '__caption__' prompt.")
continue
sample_image_inference(
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
)
# clear pipeline and cache to reduce vram usage
del pipeline
@@ -5699,6 +5709,7 @@ def sample_image_inference(
image = pipeline.latents_to_image(latents)[0]
if "original_lantent" in prompt_dict:
#Prevent out of VRAM error
if torch.cuda.is_available():
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()