Mirror of https://github.com/kohya-ss/sd-scripts.git
Merge branch 'dev' into dev_improve_log
@@ -20,7 +20,7 @@ from typing import (
     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import gc
 import glob
 import math
@@ -4646,7 +4646,6 @@ def line_to_prompt_dict(line: str) -> dict:
 
     return prompt_dict
 
-
 def sample_images_common(
     pipe_class,
     accelerator: Accelerator,
@@ -4664,6 +4663,7 @@ def sample_images_common(
     """
     StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
     """
+
     if steps == 0:
         if not args.sample_at_first:
             return
@@ -4684,8 +4684,10 @@ def sample_images_common(
         logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
         return
 
+    distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
+
     org_vae_device = vae.device  # CPUにいるはず
-    vae.to(device)
+    vae.to(distributed_state.device)
 
     # unwrap unet and text_encoder(s)
     unet = accelerator.unwrap_model(unet)
@@ -4695,10 +4697,6 @@ def sample_images_common(
     text_encoder = accelerator.unwrap_model(text_encoder)
 
     # read prompts
-
-    # with open(args.sample_prompts, "rt", encoding="utf-8") as f:
-    #     prompts = f.readlines()
-
     if args.sample_prompts.endswith(".txt"):
         with open(args.sample_prompts, "r", encoding="utf-8") as f:
             lines = f.readlines()
@@ -4711,12 +4709,11 @@ def sample_images_common(
         with open(args.sample_prompts, "r", encoding="utf-8") as f:
             prompts = json.load(f)
 
-    schedulers: dict = {}
+    # schedulers: dict = {}  cannot find where this is used
     default_scheduler = get_my_scheduler(
         sample_sampler=args.sample_sampler,
         v_parameterization=args.v_parameterization,
     )
-    schedulers[args.sample_sampler] = default_scheduler
 
     pipeline = pipe_class(
         text_encoder=text_encoder,
@@ -4729,114 +4726,145 @@ def sample_images_common(
         requires_safety_checker=False,
         clip_skip=args.clip_skip,
     )
-    pipeline.to(device)
+    pipeline.to(distributed_state.device)
 
     save_dir = args.output_dir + "/sample"
     os.makedirs(save_dir, exist_ok=True)
 
+    # preprocess prompts
+    for i in range(len(prompts)):
+        prompt_dict = prompts[i]
+        if isinstance(prompt_dict, str):
+            prompt_dict = line_to_prompt_dict(prompt_dict)
+            prompts[i] = prompt_dict
+        assert isinstance(prompt_dict, dict)
+
+        # Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
+        prompt_dict["enum"] = i
+        prompt_dict.pop("subset", None)
+
+    # save random state to restore later
     rng_state = torch.get_rng_state()
-    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None  # TODO mps etc. support
 
-    with torch.no_grad():
-        # with accelerator.autocast():
-        for i, prompt_dict in enumerate(prompts):
-            if not accelerator.is_main_process:
-                continue
-            if isinstance(prompt_dict, str):
-                prompt_dict = line_to_prompt_dict(prompt_dict)
-            assert isinstance(prompt_dict, dict)
+    if distributed_state.num_processes <= 1:
+        # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
+        with torch.no_grad():
+            for prompt_dict in prompts:
+                sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
+    else:
+        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
+        # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
+        per_process_prompts = []  # list of lists
+        for i in range(distributed_state.num_processes):
+            per_process_prompts.append(prompts[i::distributed_state.num_processes])
+
+        with torch.no_grad():
+            with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
+                for prompt_dict in prompt_dict_lists[0]:
+                    sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
-            negative_prompt = prompt_dict.get("negative_prompt")
-            sample_steps = prompt_dict.get("sample_steps", 30)
-            width = prompt_dict.get("width", 512)
-            height = prompt_dict.get("height", 512)
-            scale = prompt_dict.get("scale", 7.5)
-            seed = prompt_dict.get("seed")
-            controlnet_image = prompt_dict.get("controlnet_image")
-            prompt: str = prompt_dict.get("prompt", "")
-            sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
-
-            if seed is not None:
-                torch.manual_seed(seed)
-                torch.cuda.manual_seed(seed)
-
-            scheduler = schedulers.get(sampler_name)
-            if scheduler is None:
-                scheduler = get_my_scheduler(
-                    sample_sampler=sampler_name,
-                    v_parameterization=args.v_parameterization,
-                )
-                schedulers[sampler_name] = scheduler
-            pipeline.scheduler = scheduler
-
-            if prompt_replacement is not None:
-                prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
-                if negative_prompt is not None:
-                    negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
-
-            if controlnet_image is not None:
-                controlnet_image = Image.open(controlnet_image).convert("RGB")
-                controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
-
-            height = max(64, height - height % 8)  # round to divisible by 8
-            width = max(64, width - width % 8)  # round to divisible by 8
-            logger.info(f"prompt: {prompt}")
-            logger.info(f"negative_prompt: {negative_prompt}")
-            logger.info(f"height: {height}")
-            logger.info(f"width: {width}")
-            logger.info(f"sample_steps: {sample_steps}")
-            logger.info(f"scale: {scale}")
-            logger.info(f"sample_sampler: {sampler_name}")
-            if seed is not None:
-                logger.info(f"seed: {seed}")
-            with accelerator.autocast():
-                latents = pipeline(
-                    prompt=prompt,
-                    height=height,
-                    width=width,
-                    num_inference_steps=sample_steps,
-                    guidance_scale=scale,
-                    negative_prompt=negative_prompt,
-                    controlnet=controlnet,
-                    controlnet_image=controlnet_image,
-                )
-
-            image = pipeline.latents_to_image(latents)[0]
-
-            ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
-            num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
-            seed_suffix = "" if seed is None else f"_{seed}"
-            img_filename = (
-                f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
-            )
-
-            image.save(os.path.join(save_dir, img_filename))
-
-            # wandb有効時のみログを送信
-            try:
-                wandb_tracker = accelerator.get_tracker("wandb")
-                try:
-                    import wandb
-                except ImportError:  # 事前に一度確認するのでここはエラー出ないはず
-                    raise ImportError("No wandb / wandb がインストールされていないようです")
-
-                wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
-            except:  # wandb 無効時
-                pass
 
     # clear pipeline and cache to reduce vram usage
     del pipeline
-    torch.cuda.empty_cache()
+    with torch.cuda.device(torch.cuda.current_device()):
+        torch.cuda.empty_cache()
 
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
         torch.cuda.set_rng_state(cuda_rng_state)
     vae.to(org_vae_device)
 
+def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None):
+    assert isinstance(prompt_dict, dict)
+    negative_prompt = prompt_dict.get("negative_prompt")
+    sample_steps = prompt_dict.get("sample_steps", 30)
+    width = prompt_dict.get("width", 512)
+    height = prompt_dict.get("height", 512)
+    scale = prompt_dict.get("scale", 7.5)
+    seed = prompt_dict.get("seed")
+    controlnet_image = prompt_dict.get("controlnet_image")
+    prompt: str = prompt_dict.get("prompt", "")
+    sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
+
+    if prompt_replacement is not None:
+        prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
+        if negative_prompt is not None:
+            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+
+    if seed is not None:
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+    else:
+        # True random sample image generation
+        torch.seed()
+        torch.cuda.seed()
+
+    scheduler = get_my_scheduler(
+        sample_sampler=sampler_name,
+        v_parameterization=args.v_parameterization,
+    )
+    pipeline.scheduler = scheduler
+
+    if controlnet_image is not None:
+        controlnet_image = Image.open(controlnet_image).convert("RGB")
+        controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
+
+    height = max(64, height - height % 8)  # round to divisible by 8
+    width = max(64, width - width % 8)  # round to divisible by 8
+    logger.info(f"prompt: {prompt}")
+    logger.info(f"negative_prompt: {negative_prompt}")
+    logger.info(f"height: {height}")
+    logger.info(f"width: {width}")
+    logger.info(f"sample_steps: {sample_steps}")
+    logger.info(f"scale: {scale}")
+    logger.info(f"sample_sampler: {sampler_name}")
+    if seed is not None:
+        logger.info(f"seed: {seed}")
+    with accelerator.autocast():
+        latents = pipeline(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=sample_steps,
+            guidance_scale=scale,
+            negative_prompt=negative_prompt,
+            controlnet=controlnet,
+            controlnet_image=controlnet_image,
+        )
+
+    with torch.cuda.device(torch.cuda.current_device()):
+        torch.cuda.empty_cache()
+
+    image = pipeline.latents_to_image(latents)[0]
+
+    # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
+    # but adding 'enum' to the filename should be enough
+
+    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
+    num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
+    seed_suffix = "" if seed is None else f"_{seed}"
+    i: int = prompt_dict["enum"]
+    img_filename = (
+        f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
+    )
+    image.save(os.path.join(save_dir, img_filename))
+
+    # wandb有効時のみログを送信
+    try:
+        wandb_tracker = accelerator.get_tracker("wandb")
+        try:
+            import wandb
+        except ImportError:  # 事前に一度確認するのでここはエラー出ないはず
+            raise ImportError("No wandb / wandb がインストールされていないようです")
+
+        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
+    except:  # wandb 無効時
+        pass
 
 # endregion
 
 
 
 
 # region 前処理用
 
 
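For readers unfamiliar with accelerate's PartialState, the prompt-distribution scheme introduced in the last hunk can be exercised on its own. The sketch below is illustrative only and is not part of the commit: the file name, the stand-in prompt list, and the print statements are made up, and it assumes a recent accelerate release that provides split_between_processes. It mirrors the striping of prompts with prompts[i::num_processes] followed by split_between_processes, which hands each rank exactly one of the striped sublists.

# sample_prompt_split_demo.py -- illustrative sketch, not part of the commit
# Run with: accelerate launch --num_processes 2 sample_prompt_split_demo.py
from accelerate import PartialState

distributed_state = PartialState()  # singleton; shares state with any existing Accelerator

# stand-in for the parsed prompt list; "enum" mirrors the index added in sample_images_common
prompts = [{"prompt": f"sample prompt {i}", "enum": i} for i in range(5)]

if distributed_state.num_processes <= 1:
    # single device: iterate the original list directly
    for prompt_dict in prompts:
        print(f"rank 0 generates enum {prompt_dict['enum']}")
else:
    # stripe prompts across ranks so that enum order roughly tracks generation order
    per_process_prompts = [prompts[i::distributed_state.num_processes] for i in range(distributed_state.num_processes)]
    # split_between_processes gives each rank one element of the outer list
    with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
        for prompt_dict in prompt_dict_lists[0]:
            print(f"rank {distributed_state.process_index} generates enum {prompt_dict['enum']}")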
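The inline comment that PartialState() is a singleton is why distributed_state.device can replace the device argument previously passed to vae.to(...) and pipeline.to(...). The snippet below is a hedged illustration of that point, not code from the repository; the variable names are arbitrary.

# Illustration only: PartialState shares its state with an already-constructed Accelerator,
# so the device it reports matches the one the training loop uses on this rank.
from accelerate import Accelerator, PartialState

accelerator = Accelerator()
distributed_state = PartialState()  # returns the shared state, does not set up a new process group

assert distributed_state.device == accelerator.device
assert distributed_state.num_processes == accelerator.num_processes
print(distributed_state.process_index, distributed_state.device)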