mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
supprt deep shink with regional lora, add prompter module
This commit is contained in:
168
gen_img.py
168
gen_img.py
@@ -3,6 +3,8 @@ import json
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||
import glob
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
import inspect
|
||||
import time
|
||||
import zipfile
|
||||
@@ -333,6 +335,10 @@ class PipelineLike:
|
||||
self.scheduler = scheduler
|
||||
self.safety_checker = None
|
||||
|
||||
self.clip_vision_model: CLIPVisionModelWithProjection = None
|
||||
self.clip_vision_processor: CLIPImageProcessor = None
|
||||
self.clip_vision_strength = 0.0
|
||||
|
||||
# Textual Inversion
|
||||
self.token_replacements_list = []
|
||||
for _ in range(len(self.text_encoders)):
|
||||
@@ -419,6 +425,7 @@ class PipelineLike:
|
||||
callback_steps: Optional[int] = 1,
|
||||
img2img_noise=None,
|
||||
clip_guide_images=None,
|
||||
emb_normalize_mode: str = "original",
|
||||
**kwargs,
|
||||
):
|
||||
# TODO support secondary prompt
|
||||
@@ -493,6 +500,7 @@ class PipelineLike:
|
||||
clip_skip=self.clip_skip,
|
||||
token_replacer=token_replacer,
|
||||
device=self.device,
|
||||
emb_normalize_mode=emb_normalize_mode,
|
||||
**kwargs,
|
||||
)
|
||||
tes_text_embs.append(text_embeddings)
|
||||
@@ -508,6 +516,7 @@ class PipelineLike:
|
||||
clip_skip=self.clip_skip,
|
||||
token_replacer=token_replacer,
|
||||
device=self.device,
|
||||
emb_normalize_mode=emb_normalize_mode,
|
||||
**kwargs,
|
||||
)
|
||||
tes_real_uncond_embs.append(real_uncond_embeddings)
|
||||
@@ -1099,7 +1108,7 @@ def get_unweighted_text_embeddings(
|
||||
# in sdxl, value of clip_skip is same for Text Encoder 1 and 2
|
||||
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||
if not is_sdxl: # SD 1.5 requires final_layer_norm
|
||||
if not is_sdxl: # SD 1.5 requires final_layer_norm
|
||||
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
||||
if pool is None:
|
||||
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
|
||||
@@ -1122,7 +1131,7 @@ def get_unweighted_text_embeddings(
|
||||
else:
|
||||
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||
if not is_sdxl: # SD 1.5 requires final_layer_norm
|
||||
if not is_sdxl: # SD 1.5 requires final_layer_norm
|
||||
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
|
||||
if pool is not None:
|
||||
@@ -1143,6 +1152,7 @@ def get_weighted_text_embeddings(
|
||||
clip_skip: int = 1,
|
||||
token_replacer=None,
|
||||
device=None,
|
||||
emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none"
|
||||
**kwargs,
|
||||
):
|
||||
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
@@ -1239,16 +1249,34 @@ def get_weighted_text_embeddings(
|
||||
# assign weights to the prompts and normalize in the sense of mean
|
||||
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
||||
# →全体でいいんじゃないかな
|
||||
|
||||
if (not skip_parsing) and (not skip_weighting):
|
||||
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
if uncond_prompt is not None:
|
||||
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
if emb_normalize_mode == "abs":
|
||||
previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||
current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
if uncond_prompt is not None:
|
||||
previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||
current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
elif emb_normalize_mode == "none":
|
||||
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||
if uncond_prompt is not None:
|
||||
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||
|
||||
else: # "original"
|
||||
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
if uncond_prompt is not None:
|
||||
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if uncond_prompt is not None:
|
||||
return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
|
||||
@@ -1427,6 +1455,27 @@ class BatchData(NamedTuple):
|
||||
ext: BatchDataExt
|
||||
|
||||
|
||||
class ListPrompter:
|
||||
def __init__(self, prompts: List[str]):
|
||||
self.prompts = prompts
|
||||
self.index = 0
|
||||
|
||||
def shuffle(self):
|
||||
random.shuffle(self.prompts)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.prompts)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.index >= len(self.prompts):
|
||||
self.index = 0 # reset
|
||||
return None
|
||||
|
||||
prompt = self.prompts[self.index]
|
||||
self.index += 1
|
||||
return prompt
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.fp16:
|
||||
dtype = torch.float16
|
||||
@@ -1951,15 +2000,35 @@ def main(args):
|
||||
token_embeds2[token_id] = embed
|
||||
|
||||
# promptを取得する
|
||||
prompt_list = None
|
||||
if args.from_file is not None:
|
||||
print(f"reading prompts from {args.from_file}")
|
||||
with open(args.from_file, "r", encoding="utf-8") as f:
|
||||
prompt_list = f.read().splitlines()
|
||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
|
||||
prompter = ListPrompter(prompt_list)
|
||||
|
||||
elif args.from_module is not None:
|
||||
|
||||
def load_module_from_path(module_name, file_path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||
if spec is None:
|
||||
raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
print(f"reading prompts from module: {args.from_module}")
|
||||
prompt_module = load_module_from_path("prompt_module", args.from_module)
|
||||
|
||||
prompter = prompt_module.get_prompter(args, pipe, networks)
|
||||
|
||||
elif args.prompt is not None:
|
||||
prompt_list = [args.prompt]
|
||||
prompter = ListPrompter([args.prompt])
|
||||
|
||||
else:
|
||||
prompt_list = []
|
||||
prompter = None # interactive mode
|
||||
|
||||
if args.interactive:
|
||||
args.n_iter = 1
|
||||
@@ -2026,14 +2095,16 @@ def main(args):
|
||||
mask_images = None
|
||||
|
||||
# promptがないとき、画像のPngInfoから取得する
|
||||
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
|
||||
if init_images is not None and prompter is None and not args.interactive:
|
||||
print("get prompts from images' metadata")
|
||||
prompt_list = []
|
||||
for img in init_images:
|
||||
if "prompt" in img.text:
|
||||
prompt = img.text["prompt"]
|
||||
if "negative-prompt" in img.text:
|
||||
prompt += " --n " + img.text["negative-prompt"]
|
||||
prompt_list.append(prompt)
|
||||
prompter = ListPrompter(prompt_list)
|
||||
|
||||
# プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する)
|
||||
l = []
|
||||
@@ -2105,15 +2176,18 @@ def main(args):
|
||||
else:
|
||||
guide_images = None
|
||||
|
||||
# seed指定時はseedを決めておく
|
||||
# 新しい乱数生成器を作成する
|
||||
if args.seed is not None:
|
||||
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
|
||||
random.seed(args.seed)
|
||||
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
|
||||
if len(predefined_seeds) == 1:
|
||||
predefined_seeds[0] = args.seed
|
||||
if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1:
|
||||
# 引数のseedをそのまま使う
|
||||
def fixed_seed(*args, **kwargs):
|
||||
return args.seed
|
||||
|
||||
seed_random = SimpleNamespace(randint=fixed_seed)
|
||||
else:
|
||||
seed_random = random.Random(args.seed)
|
||||
else:
|
||||
predefined_seeds = None
|
||||
seed_random = random.Random()
|
||||
|
||||
# デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み)
|
||||
if args.W is None:
|
||||
@@ -2127,11 +2201,14 @@ def main(args):
|
||||
|
||||
for gen_iter in range(args.n_iter):
|
||||
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||
iter_seed = random.randint(0, 0x7FFFFFFF)
|
||||
if args.iter_same_seed:
|
||||
iter_seed = seed_random.randint(0, 2**32 - 1)
|
||||
else:
|
||||
iter_seed = None
|
||||
|
||||
# shuffle prompt list
|
||||
if args.shuffle_prompts:
|
||||
random.shuffle(prompt_list)
|
||||
prompter.shuffle()
|
||||
|
||||
# バッチ処理の関数
|
||||
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
|
||||
@@ -2352,7 +2429,8 @@ def main(args):
|
||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||
n.set_multiplier(m)
|
||||
if regional_network:
|
||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||
# TODO バッチから ds_ratio を取り出すべき
|
||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio)
|
||||
|
||||
if not regional_network and network_pre_calc:
|
||||
for n in networks:
|
||||
@@ -2386,6 +2464,7 @@ def main(args):
|
||||
return_latents=return_latents,
|
||||
clip_prompts=clip_prompts,
|
||||
clip_guide_images=guide_images,
|
||||
emb_normalize_mode=args.emb_normalize_mode,
|
||||
)
|
||||
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
||||
return images
|
||||
@@ -2451,8 +2530,8 @@ def main(args):
|
||||
prompt_index = 0
|
||||
global_step = 0
|
||||
batch_data = []
|
||||
while args.interactive or prompt_index < len(prompt_list):
|
||||
if len(prompt_list) == 0:
|
||||
while True:
|
||||
if args.interactive:
|
||||
# interactive
|
||||
valid = False
|
||||
while not valid:
|
||||
@@ -2466,7 +2545,9 @@ def main(args):
|
||||
if not valid: # EOF, end app
|
||||
break
|
||||
else:
|
||||
raw_prompt = prompt_list[prompt_index]
|
||||
raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step)
|
||||
if raw_prompt is None:
|
||||
break
|
||||
|
||||
# sd-dynamic-prompts like variants:
|
||||
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
|
||||
@@ -2513,7 +2594,8 @@ def main(args):
|
||||
|
||||
prompt_args = raw_prompt.strip().split(" --")
|
||||
prompt = prompt_args[0]
|
||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
||||
length = len(prompter) if hasattr(prompter, "__len__") else 0
|
||||
print(f"prompt {prompt_index+1}/{length}: {prompt}")
|
||||
|
||||
for parg in prompt_args[1:]:
|
||||
try:
|
||||
@@ -2731,23 +2813,17 @@ def main(args):
|
||||
|
||||
# prepare seed
|
||||
if seeds is not None: # given in prompt
|
||||
# 数が足りないなら前のをそのまま使う
|
||||
# num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う
|
||||
if len(seeds) > 0:
|
||||
seed = seeds.pop(0)
|
||||
else:
|
||||
if predefined_seeds is not None:
|
||||
if len(predefined_seeds) > 0:
|
||||
seed = predefined_seeds.pop(0)
|
||||
else:
|
||||
print("predefined seeds are exhausted")
|
||||
seed = None
|
||||
elif args.iter_same_seed:
|
||||
seeds = iter_seed
|
||||
if args.iter_same_seed:
|
||||
seed = iter_seed
|
||||
else:
|
||||
seed = None # 前のを消す
|
||||
|
||||
if seed is None:
|
||||
seed = random.randint(0, 0x7FFFFFFF)
|
||||
seed = seed_random.randint(0, 2**32 - 1)
|
||||
if args.interactive:
|
||||
print(f"seed: {seed}")
|
||||
|
||||
@@ -2853,6 +2929,15 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--from_module",
|
||||
type=str,
|
||||
default=None,
|
||||
help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--interactive",
|
||||
action="store_true",
|
||||
@@ -3067,6 +3152,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--emb_normalize_mode",
|
||||
type=str,
|
||||
default="original",
|
||||
choices=["original", "none", "abs"],
|
||||
help="embedding normalization mode / embeddingの正規化モード",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user