mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Update train_util.py
This commit is contained in:
@@ -5472,18 +5472,23 @@ def sample_images_common(
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
# read prompts
|
||||
if args.sample_prompts.endswith(".txt"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif args.sample_prompts.endswith(".toml"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif args.sample_prompts.endswith(".json"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
|
||||
if distributed_state.is_main_process:
|
||||
# Load prompts into prompts list on main process only
|
||||
if args.sample_prompts.endswith(".txt"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif args.sample_prompts.endswith(".toml"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif args.sample_prompts.endswith(".json"):
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
else:
|
||||
prompts = [] # Init empty prompts list for sub processes.
|
||||
|
||||
# schedulers: dict = {} cannot find where this is used
|
||||
default_scheduler = get_my_scheduler(
|
||||
sample_sampler=args.sample_sampler,
|
||||
@@ -5502,38 +5507,41 @@ def sample_images_common(
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.to(distributed_state.device)
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
|
||||
# preprocess prompts
|
||||
idx = 0
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
|
||||
if '__caption__' in prompt_dict.get("prompt") and example_tuple:
|
||||
logger.info(f"Original prompt: {prompt_dict.get('prompt')}")
|
||||
|
||||
while example_tuple[1][idx] == '':
|
||||
if distributed_state.is_main_process:
|
||||
#Create output folder and preprocess prompts on main process only.
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
idx = 0
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
|
||||
if '__caption__' in prompt_dict.get("prompt") and example_tuple:
|
||||
logger.info(f"Original prompt: {prompt_dict.get('prompt')}")
|
||||
|
||||
while example_tuple[1][idx] == '':
|
||||
idx = (idx + 1) % len(example_tuple[1])
|
||||
if idx == 0:
|
||||
break
|
||||
prompt_dict["prompt"] = prompt_dict.get("prompt").replace('__caption__', f'{example_tuple[1][idx]}')
|
||||
logger.info(f"Replacement prompt: {prompt_dict["prompt"]}")
|
||||
prompt_dict["height"] = example_tuple[0].shape[2] * 8
|
||||
logger.info(f"Original Image Height: {prompt_dict["height"]}")
|
||||
prompt_dict["width"] = example_tuple[0].shape[3] * 8
|
||||
logger.info(f"Original Image Width: {prompt_dict["width"]}")
|
||||
prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
|
||||
idx = (idx + 1) % len(example_tuple[1])
|
||||
if idx == 0:
|
||||
break
|
||||
prompt_dict["prompt"] = f"{example_tuple[1][idx]}"
|
||||
logger.info(f"Replacement prompt: {example_tuple[1][idx]}")
|
||||
prompt_dict["height"] = example_tuple[0].shape[2] * 8
|
||||
logger.info(f"Original Image Height: {prompt_dict["height"]}")
|
||||
prompt_dict["width"] = example_tuple[0].shape[3] * 8
|
||||
logger.info(f"Original Image Width: {prompt_dict["width"]}")
|
||||
prompt_dict["original_lantent"] = example_tuple[0][idx].unsqueeze(0)
|
||||
idx = (idx + 1) % len(example_tuple[1])
|
||||
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
prompts[i] = prompt_dict
|
||||
logger.info(f"Current prompt: {prompts[i].get('prompt')}")
|
||||
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
prompts[i] = prompt_dict
|
||||
logger.info(f"Current prompt: {prompts[i].get('prompt')}")
|
||||
|
||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||
|
||||
@@ -5546,32 +5554,34 @@ def sample_images_common(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if distributed_state.num_processes <= 1:
|
||||
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||
with torch.no_grad():
|
||||
for prompt_dict in prompts:
|
||||
if prompt_dict["prompt"] == '__caption__':
|
||||
if distributed_state.num_processes > 1 and distributed_state.is_main_process:
|
||||
per_process_prompts = [] # list of lists
|
||||
for i in range(distributed_state.num_processes):
|
||||
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
||||
prompts = []
|
||||
for prompt in per_process_prompts:
|
||||
prompts.extend(prompt)
|
||||
distributed_state.wait_for_everyone()
|
||||
per_process_prompts = gather_object(prompts)
|
||||
prompts = []
|
||||
for prompt in per_process_prompts:
|
||||
prompts.extend(prompt)
|
||||
|
||||
|
||||
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
with distributed_state.split_between_processes(prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists:
|
||||
if '__caption__' in prompt_dict.get("prompt"):
|
||||
logger.info("No training prompts loaded, skipping '__caption__' prompt.")
|
||||
continue
|
||||
sample_image_inference(
|
||||
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||
per_process_prompts = [] # list of lists
|
||||
for i in range(distributed_state.num_processes):
|
||||
per_process_prompts.append(prompts[i :: distributed_state.num_processes])
|
||||
|
||||
with torch.no_grad():
|
||||
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
||||
for prompt_dict in prompt_dict_lists[0]:
|
||||
if prompt_dict["prompt"] == '__caption__':
|
||||
logger.info("No training prompts loaded, skipping '__caption__' prompt.")
|
||||
continue
|
||||
sample_image_inference(
|
||||
accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet
|
||||
)
|
||||
|
||||
# clear pipeline and cache to reduce vram usage
|
||||
del pipeline
|
||||
@@ -5699,6 +5709,7 @@ def sample_image_inference(
|
||||
|
||||
image = pipeline.latents_to_image(latents)[0]
|
||||
if "original_lantent" in prompt_dict:
|
||||
#Prevent out of VRAM error
|
||||
if torch.cuda.is_available():
|
||||
with torch.cuda.device(torch.cuda.current_device()):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user