support deep shrink with regional LoRA, add prompter module

This commit is contained in:
Kohya S
2024-02-12 14:17:27 +09:00
parent d3745db764
commit cbe9c5dc06
2 changed files with 161 additions and 45 deletions

View File

@@ -3,6 +3,8 @@ import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import importlib.util
import sys
import inspect
import time
import zipfile
@@ -333,6 +335,10 @@ class PipelineLike:
self.scheduler = scheduler
self.safety_checker = None
self.clip_vision_model: CLIPVisionModelWithProjection = None
self.clip_vision_processor: CLIPImageProcessor = None
self.clip_vision_strength = 0.0
# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
@@ -419,6 +425,7 @@ class PipelineLike:
callback_steps: Optional[int] = 1,
img2img_noise=None,
clip_guide_images=None,
emb_normalize_mode: str = "original",
**kwargs,
):
# TODO support secondary prompt
@@ -493,6 +500,7 @@ class PipelineLike:
clip_skip=self.clip_skip,
token_replacer=token_replacer,
device=self.device,
emb_normalize_mode=emb_normalize_mode,
**kwargs,
)
tes_text_embs.append(text_embeddings)
@@ -508,6 +516,7 @@ class PipelineLike:
clip_skip=self.clip_skip,
token_replacer=token_replacer,
device=self.device,
emb_normalize_mode=emb_normalize_mode,
**kwargs,
)
tes_real_uncond_embs.append(real_uncond_embeddings)
@@ -1099,7 +1108,7 @@ def get_unweighted_text_embeddings(
# in sdxl, value of clip_skip is same for Text Encoder 1 and 2
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-clip_skip]
if not is_sdxl: # SD 1.5 requires final_layer_norm
if not is_sdxl: # SD 1.5 requires final_layer_norm
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
if pool is None:
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
@@ -1122,7 +1131,7 @@ def get_unweighted_text_embeddings(
else:
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
text_embeddings = enc_out["hidden_states"][-clip_skip]
if not is_sdxl: # SD 1.5 requires final_layer_norm
if not is_sdxl: # SD 1.5 requires final_layer_norm
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
if pool is not None:
@@ -1143,6 +1152,7 @@ def get_weighted_text_embeddings(
clip_skip: int = 1,
token_replacer=None,
device=None,
emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none"
**kwargs,
):
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -1239,16 +1249,34 @@ def get_weighted_text_embeddings(
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
# →全体でいいんじゃないかな
if (not skip_parsing) and (not skip_weighting):
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if emb_normalize_mode == "abs":
previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
elif emb_normalize_mode == "none":
text_embeddings *= prompt_weights.unsqueeze(-1)
if uncond_prompt is not None:
uncond_embeddings *= uncond_weights.unsqueeze(-1)
else: # "original"
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
@@ -1427,6 +1455,27 @@ class BatchData(NamedTuple):
ext: BatchDataExt
class ListPrompter:
    """Callable prompt source that hands out prompts from a list, one per call.

    Each call returns the next prompt in order. When every prompt has been
    returned, the following call returns ``None`` once (signalling the end of
    a pass) and rewinds, so the call after that starts again from the first
    prompt.
    """

    def __init__(self, prompts: List[str]):
        """Store a private copy of ``prompts``.

        The copy keeps ``shuffle()`` from reordering the list object owned by
        the caller; any iterable of strings is accepted.
        """
        self.prompts = list(prompts)
        self.index = 0  # position of the next prompt to hand out

    def shuffle(self):
        """Randomly reorder the stored prompts using the module-level ``random`` state."""
        random.shuffle(self.prompts)

    def __len__(self):
        """Return the number of stored prompts."""
        return len(self.prompts)

    def __call__(self, *args, **kwargs):
        """Return the next prompt, or ``None`` when the list is exhausted.

        All positional and keyword arguments are accepted and ignored, so the
        instance can be called with whatever context arguments the caller
        supplies to prompters.
        """
        if self.index >= len(self.prompts):
            self.index = 0  # reset so the next call begins a new pass
            return None

        prompt = self.prompts[self.index]
        self.index += 1
        return prompt
def main(args):
if args.fp16:
dtype = torch.float16
@@ -1951,15 +2000,35 @@ def main(args):
token_embeds2[token_id] = embed
# promptを取得する
prompt_list = None
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
prompter = ListPrompter(prompt_list)
elif args.from_module is not None:
def load_module_from_path(module_name, file_path):
    """Import the Python source file at ``file_path`` under the name ``module_name``.

    The module is registered in ``sys.modules`` before it is executed, as in
    the standard importlib recipe. If execution fails, the registration is
    rolled back so no half-initialised module is left behind, and the
    original exception is re-raised.

    Args:
        module_name: name to register the module under in ``sys.modules``.
        file_path: path of the source file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: if no import spec or loader can be determined for
            ``file_path``.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'")

    module = importlib.util.module_from_spec(spec)

    # Remember what (if anything) was registered under this name so a failed
    # execution can restore the previous state.
    previous = sys.modules.get(module_name)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is None:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module
print(f"reading prompts from module: {args.from_module}")
prompt_module = load_module_from_path("prompt_module", args.from_module)
prompter = prompt_module.get_prompter(args, pipe, networks)
elif args.prompt is not None:
prompt_list = [args.prompt]
prompter = ListPrompter([args.prompt])
else:
prompt_list = []
prompter = None # interactive mode
if args.interactive:
args.n_iter = 1
@@ -2026,14 +2095,16 @@ def main(args):
mask_images = None
# promptがないとき、画像のPngInfoから取得する
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
if init_images is not None and prompter is None and not args.interactive:
print("get prompts from images' metadata")
prompt_list = []
for img in init_images:
if "prompt" in img.text:
prompt = img.text["prompt"]
if "negative-prompt" in img.text:
prompt += " --n " + img.text["negative-prompt"]
prompt_list.append(prompt)
prompter = ListPrompter(prompt_list)
# プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する)
l = []
@@ -2105,15 +2176,18 @@ def main(args):
else:
guide_images = None
# seed指定時はseedを決めておく
# 新しい乱数生成器を作成する
if args.seed is not None:
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
random.seed(args.seed)
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
if len(predefined_seeds) == 1:
predefined_seeds[0] = args.seed
if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1:
# 引数のseedをそのまま使う
def fixed_seed(*_args, **_kwargs):
    """Return the seed given on the command line, ignoring all arguments.

    Intended as a drop-in for ``randint``: whatever bounds are passed, the
    result is always ``args.seed`` from the enclosing scope. The ignored
    parameters are deliberately NOT named ``args``, because that would shadow
    the enclosing ``args`` namespace with a tuple and make ``args.seed`` fail.
    """
    return args.seed
seed_random = SimpleNamespace(randint=fixed_seed)
else:
seed_random = random.Random(args.seed)
else:
predefined_seeds = None
seed_random = random.Random()
# デフォルト画像サイズを設定するimg2imgではこれらの値は無視されるまたはW*Hにリサイズ済み
if args.W is None:
@@ -2127,11 +2201,14 @@ def main(args):
for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7FFFFFFF)
if args.iter_same_seed:
iter_seed = seed_random.randint(0, 2**32 - 1)
else:
iter_seed = None
# shuffle prompt list
if args.shuffle_prompts:
random.shuffle(prompt_list)
prompter.shuffle()
# バッチ処理の関数
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
@@ -2352,7 +2429,8 @@ def main(args):
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
# TODO バッチから ds_ratio を取り出すべき
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio)
if not regional_network and network_pre_calc:
for n in networks:
@@ -2386,6 +2464,7 @@ def main(args):
return_latents=return_latents,
clip_prompts=clip_prompts,
clip_guide_images=guide_images,
emb_normalize_mode=args.emb_normalize_mode,
)
if highres_1st and not args.highres_fix_save_1st: # return images or latents
return images
@@ -2451,8 +2530,8 @@ def main(args):
prompt_index = 0
global_step = 0
batch_data = []
while args.interactive or prompt_index < len(prompt_list):
if len(prompt_list) == 0:
while True:
if args.interactive:
# interactive
valid = False
while not valid:
@@ -2466,7 +2545,9 @@ def main(args):
if not valid: # EOF, end app
break
else:
raw_prompt = prompt_list[prompt_index]
raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step)
if raw_prompt is None:
break
# sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
@@ -2513,7 +2594,8 @@ def main(args):
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
length = len(prompter) if hasattr(prompter, "__len__") else 0
print(f"prompt {prompt_index+1}/{length}: {prompt}")
for parg in prompt_args[1:]:
try:
@@ -2731,23 +2813,17 @@ def main(args):
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
# num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う
if len(seeds) > 0:
seed = seeds.pop(0)
else:
if predefined_seeds is not None:
if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0)
else:
print("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seeds = iter_seed
if args.iter_same_seed:
seed = iter_seed
else:
seed = None # 前のを消す
if seed is None:
seed = random.randint(0, 0x7FFFFFFF)
seed = seed_random.randint(0, 2**32 - 1)
if args.interactive:
print(f"seed: {seed}")
@@ -2853,6 +2929,15 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
)
parser.add_argument(
"--from_module",
type=str,
default=None,
help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む",
)
parser.add_argument(
"--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数"
)
parser.add_argument(
"--interactive",
action="store_true",
@@ -3067,6 +3152,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
)
parser.add_argument(
"--emb_normalize_mode",
type=str,
default="original",
choices=["original", "none", "abs"],
help="embedding normalization mode / embeddingの正規化モード",
)
parser.add_argument(
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
)