mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
simplify multi-GPU sample generation
This commit is contained in:
@@ -4668,13 +4668,13 @@ def sample_images_common(
|
|||||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||||
return
|
return
|
||||||
|
|
||||||
distributed_state = PartialState() #testing implementation of multi gpu distributed inference
|
|
||||||
|
|
||||||
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
|
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||||
if not os.path.isfile(args.sample_prompts):
|
if not os.path.isfile(args.sample_prompts):
|
||||||
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
||||||
|
|
||||||
org_vae_device = vae.device # CPUにいるはず
|
org_vae_device = vae.device # CPUにいるはず
|
||||||
vae.to(distributed_state.device)
|
vae.to(distributed_state.device)
|
||||||
|
|
||||||
@@ -4686,10 +4686,6 @@ def sample_images_common(
|
|||||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||||
|
|
||||||
# read prompts
|
# read prompts
|
||||||
|
|
||||||
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
|
|
||||||
# prompts = f.readlines()
|
|
||||||
|
|
||||||
if args.sample_prompts.endswith(".txt"):
|
if args.sample_prompts.endswith(".txt"):
|
||||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
@@ -4723,21 +4719,38 @@ def sample_images_common(
|
|||||||
save_dir = args.output_dir + "/sample"
|
save_dir = args.output_dir + "/sample"
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# preprocess prompts
|
||||||
|
for i in range(len(prompts)):
|
||||||
|
prompt_dict = prompts[i]
|
||||||
|
if isinstance(prompt_dict, str):
|
||||||
|
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||||
|
prompts[i] = prompt_dict
|
||||||
|
assert isinstance(prompt_dict, dict)
|
||||||
|
|
||||||
|
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||||
|
prompt_dict["enum"] = i
|
||||||
|
prompt_dict.pop("subset", None)
|
||||||
|
|
||||||
|
# save random state to restore later
|
||||||
|
rng_state = torch.get_rng_state()
|
||||||
|
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support
|
||||||
|
|
||||||
|
if distributed_state.num_processes <= 1:
|
||||||
|
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
|
||||||
|
with torch.no_grad():
|
||||||
|
for prompt_dict in prompts:
|
||||||
|
sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
|
||||||
|
else:
|
||||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
|
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
|
||||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
||||||
per_process_prompts = generate_per_device_prompt_list(prompts, num_of_processes = distributed_state.num_processes, prompt_replacement = prompt_replacement)
|
per_process_prompts = [] # list of lists
|
||||||
|
for i in range(distributed_state.num_processes):
|
||||||
rng_state = torch.get_rng_state()
|
per_process_prompts.append(prompts[i::distributed_state.num_processes])
|
||||||
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
|
|
||||||
# True random sample image generation
|
|
||||||
torch.seed()
|
|
||||||
torch.cuda.seed()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
|
||||||
for prompt_dict in prompt_dict_lists[0]:
|
for prompt_dict in prompt_dict_lists[0]:
|
||||||
sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=controlnet)
|
sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
|
||||||
|
|
||||||
|
|
||||||
# clear pipeline and cache to reduce vram usage
|
# clear pipeline and cache to reduce vram usage
|
||||||
del pipeline
|
del pipeline
|
||||||
@@ -4750,27 +4763,7 @@ def sample_images_common(
|
|||||||
torch.cuda.set_rng_state(cuda_rng_state)
|
torch.cuda.set_rng_state(cuda_rng_state)
|
||||||
vae.to(org_vae_device)
|
vae.to(org_vae_device)
|
||||||
|
|
||||||
def generate_per_device_prompt_list(prompts, num_of_processes, prompt_replacement=None):
|
def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None):
|
||||||
|
|
||||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processess available (number of devices available)
|
|
||||||
# prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
|
|
||||||
per_process_prompts = [[] for i in range(num_of_processes)]
|
|
||||||
for i, prompt in enumerate(prompts):
|
|
||||||
if isinstance(prompt, str):
|
|
||||||
prompt = line_to_prompt_dict(prompt)
|
|
||||||
assert isinstance(prompt, dict)
|
|
||||||
prompt.pop("subset", None) # Clean up subset key
|
|
||||||
prompt["enum"] = i
|
|
||||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
|
||||||
if prompt_replacement is not None:
|
|
||||||
prompt["prompt"] = prompt["prompt"].replace(prompt_replacement[0], prompt_replacement[1])
|
|
||||||
if prompt["negative_prompt"] is not None:
|
|
||||||
prompt["negative_prompt"] = prompt["negative_prompt"].replace(prompt_replacement[0], prompt_replacement[1])
|
|
||||||
# Refactor prompt replacement to here in order to simplify sample_image_inference function.
|
|
||||||
per_process_prompts[i % num_of_processes].append(prompt)
|
|
||||||
return per_process_prompts
|
|
||||||
|
|
||||||
def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, controlnet=None):
|
|
||||||
assert isinstance(prompt_dict, dict)
|
assert isinstance(prompt_dict, dict)
|
||||||
negative_prompt = prompt_dict.get("negative_prompt")
|
negative_prompt = prompt_dict.get("negative_prompt")
|
||||||
sample_steps = prompt_dict.get("sample_steps", 30)
|
sample_steps = prompt_dict.get("sample_steps", 30)
|
||||||
@@ -4782,9 +4775,18 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
|
|||||||
prompt: str = prompt_dict.get("prompt", "")
|
prompt: str = prompt_dict.get("prompt", "")
|
||||||
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||||
|
|
||||||
|
if prompt_replacement is not None:
|
||||||
|
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
|
if negative_prompt is not None:
|
||||||
|
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed(seed)
|
torch.cuda.manual_seed(seed)
|
||||||
|
else:
|
||||||
|
# True random sample image generation
|
||||||
|
torch.seed()
|
||||||
|
torch.cuda.seed()
|
||||||
|
|
||||||
scheduler = get_my_scheduler(
|
scheduler = get_my_scheduler(
|
||||||
sample_sampler=sampler_name,
|
sample_sampler=sampler_name,
|
||||||
@@ -4819,7 +4821,10 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
|
|||||||
controlnet_image=controlnet_image,
|
controlnet_image=controlnet_image,
|
||||||
)
|
)
|
||||||
image = pipeline.latents_to_image(latents)[0]
|
image = pipeline.latents_to_image(latents)[0]
|
||||||
|
|
||||||
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
|
||||||
|
# but adding 'enum' to the filename should be enough
|
||||||
|
|
||||||
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
|
||||||
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
|
||||||
seed_suffix = "" if seed is None else f"_{seed}"
|
seed_suffix = "" if seed is None else f"_{seed}"
|
||||||
@@ -4827,11 +4832,8 @@ def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, p
|
|||||||
img_filename = (
|
img_filename = (
|
||||||
f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
|
||||||
)
|
)
|
||||||
|
|
||||||
image.save(os.path.join(save_dir, img_filename))
|
image.save(os.path.join(save_dir, img_filename))
|
||||||
if seed is not None:
|
|
||||||
torch.seed()
|
|
||||||
torch.cuda.seed()
|
|
||||||
# wandb有効時のみログを送信
|
# wandb有効時のみログを送信
|
||||||
try:
|
try:
|
||||||
wandb_tracker = accelerator.get_tracker("wandb")
|
wandb_tracker = accelerator.get_tracker("wandb")
|
||||||
|
|||||||
Reference in New Issue
Block a user