support deep shrink with regional LoRA, add prompter module

This commit is contained in:
Kohya S
2024-02-12 14:17:27 +09:00
parent d3745db764
commit cbe9c5dc06
2 changed files with 161 additions and 45 deletions

View File

@@ -3,6 +3,8 @@ import json
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob
import importlib
import importlib.util
import sys
import inspect
import time
import zipfile
@@ -333,6 +335,10 @@ class PipelineLike:
self.scheduler = scheduler
self.safety_checker = None
self.clip_vision_model: CLIPVisionModelWithProjection = None
self.clip_vision_processor: CLIPImageProcessor = None
self.clip_vision_strength = 0.0
# Textual Inversion
self.token_replacements_list = []
for _ in range(len(self.text_encoders)):
@@ -419,6 +425,7 @@ class PipelineLike:
callback_steps: Optional[int] = 1,
img2img_noise=None,
clip_guide_images=None,
emb_normalize_mode: str = "original",
**kwargs,
):
# TODO support secondary prompt
@@ -493,6 +500,7 @@ class PipelineLike:
clip_skip=self.clip_skip,
token_replacer=token_replacer,
device=self.device,
emb_normalize_mode=emb_normalize_mode,
**kwargs,
)
tes_text_embs.append(text_embeddings)
@@ -508,6 +516,7 @@ class PipelineLike:
clip_skip=self.clip_skip,
token_replacer=token_replacer,
device=self.device,
emb_normalize_mode=emb_normalize_mode,
**kwargs,
)
tes_real_uncond_embs.append(real_uncond_embeddings)
@@ -1143,6 +1152,7 @@ def get_weighted_text_embeddings(
clip_skip: int = 1,
token_replacer=None,
device=None,
emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none"
**kwargs,
):
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
@@ -1239,7 +1249,25 @@ def get_weighted_text_embeddings(
# assign weights to the prompts and normalize in the sense of mean
# TODO: should we normalize by chunk or in a whole (current implementation)?
# →全体でいいんじゃないかな
if (not skip_parsing) and (not skip_weighting):
if emb_normalize_mode == "abs":
previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
if uncond_prompt is not None:
previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= uncond_weights.unsqueeze(-1)
current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
elif emb_normalize_mode == "none":
text_embeddings *= prompt_weights.unsqueeze(-1)
if uncond_prompt is not None:
uncond_embeddings *= uncond_weights.unsqueeze(-1)
else: # "original"
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
text_embeddings *= prompt_weights.unsqueeze(-1)
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
@@ -1427,6 +1455,27 @@ class BatchData(NamedTuple):
ext: BatchDataExt
class ListPrompter:
    """Serves prompts from a fixed list, one per call, in list order.

    Once every prompt has been served, a call returns None and the
    prompter rewinds to the start, so callers can detect the end of a
    pass and reuse the same instance for the next iteration.
    """

    def __init__(self, prompts: List[str]):
        self.prompts = prompts
        self.index = 0

    def shuffle(self):
        """Shuffle the prompt list in place."""
        random.shuffle(self.prompts)

    def __len__(self):
        return len(self.prompts)

    def __call__(self, *args, **kwargs):
        exhausted = self.index >= len(self.prompts)
        if exhausted:
            self.index = 0  # rewind so the prompter can be reused
            return None
        current = self.prompts[self.index]
        self.index += 1
        return current
def main(args):
if args.fp16:
dtype = torch.float16
@@ -1951,15 +2000,35 @@ def main(args):
token_embeds2[token_id] = embed
# promptを取得する
prompt_list = None
if args.from_file is not None:
print(f"reading prompts from {args.from_file}")
with open(args.from_file, "r", encoding="utf-8") as f:
prompt_list = f.read().splitlines()
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
prompter = ListPrompter(prompt_list)
elif args.from_module is not None:
def load_module_from_path(module_name, file_path):
spec = importlib.util.spec_from_file_location(module_name, file_path)
if spec is None:
raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
print(f"reading prompts from module: {args.from_module}")
prompt_module = load_module_from_path("prompt_module", args.from_module)
prompter = prompt_module.get_prompter(args, pipe, networks)
elif args.prompt is not None:
prompt_list = [args.prompt]
prompter = ListPrompter([args.prompt])
else:
prompt_list = []
prompter = None # interactive mode
if args.interactive:
args.n_iter = 1
@@ -2026,14 +2095,16 @@ def main(args):
mask_images = None
# promptがないとき、画像のPngInfoから取得する
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
if init_images is not None and prompter is None and not args.interactive:
print("get prompts from images' metadata")
prompt_list = []
for img in init_images:
if "prompt" in img.text:
prompt = img.text["prompt"]
if "negative-prompt" in img.text:
prompt += " --n " + img.text["negative-prompt"]
prompt_list.append(prompt)
prompter = ListPrompter(prompt_list)
# プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する)
l = []
@@ -2105,15 +2176,18 @@ def main(args):
else:
guide_images = None
# seed指定時はseedを決めておく
# 新しい乱数生成器を作成する
if args.seed is not None:
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
random.seed(args.seed)
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
if len(predefined_seeds) == 1:
predefined_seeds[0] = args.seed
if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1:
# use the seed given on the command line as-is
def fixed_seed(*_args, **_kwargs):
    # BUG FIX: the vararg must not be named ``*args`` — inside the
    # function that name would shadow the enclosing argparse
    # namespace ``args``, so ``args.seed`` would be an attribute
    # lookup on the positional-argument tuple and raise
    # AttributeError on the first call.
    return args.seed
seed_random = SimpleNamespace(randint=fixed_seed)
else:
predefined_seeds = None
seed_random = random.Random(args.seed)
else:
seed_random = random.Random()
# デフォルト画像サイズを設定するimg2imgではこれらの値は無視されるまたはW*Hにリサイズ済み
if args.W is None:
@@ -2127,11 +2201,14 @@ def main(args):
for gen_iter in range(args.n_iter):
print(f"iteration {gen_iter+1}/{args.n_iter}")
iter_seed = random.randint(0, 0x7FFFFFFF)
if args.iter_same_seed:
iter_seed = seed_random.randint(0, 2**32 - 1)
else:
iter_seed = None
# shuffle prompt list
if args.shuffle_prompts:
random.shuffle(prompt_list)
prompter.shuffle()
# バッチ処理の関数
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
@@ -2352,7 +2429,8 @@ def main(args):
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
if regional_network:
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
# TODO バッチから ds_ratio を取り出すべき
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio)
if not regional_network and network_pre_calc:
for n in networks:
@@ -2386,6 +2464,7 @@ def main(args):
return_latents=return_latents,
clip_prompts=clip_prompts,
clip_guide_images=guide_images,
emb_normalize_mode=args.emb_normalize_mode,
)
if highres_1st and not args.highres_fix_save_1st: # return images or latents
return images
@@ -2451,8 +2530,8 @@ def main(args):
prompt_index = 0
global_step = 0
batch_data = []
while args.interactive or prompt_index < len(prompt_list):
if len(prompt_list) == 0:
while True:
if args.interactive:
# interactive
valid = False
while not valid:
@@ -2466,7 +2545,9 @@ def main(args):
if not valid: # EOF, end app
break
else:
raw_prompt = prompt_list[prompt_index]
raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step)
if raw_prompt is None:
break
# sd-dynamic-prompts like variants:
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
@@ -2513,7 +2594,8 @@ def main(args):
prompt_args = raw_prompt.strip().split(" --")
prompt = prompt_args[0]
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
length = len(prompter) if hasattr(prompter, "__len__") else 0
print(f"prompt {prompt_index+1}/{length}: {prompt}")
for parg in prompt_args[1:]:
try:
@@ -2731,23 +2813,17 @@ def main(args):
# prepare seed
if seeds is not None: # given in prompt
# 数が足りないなら前のをそのまま使う
# num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う
if len(seeds) > 0:
seed = seeds.pop(0)
else:
if predefined_seeds is not None:
if len(predefined_seeds) > 0:
seed = predefined_seeds.pop(0)
else:
print("predefined seeds are exhausted")
seed = None
elif args.iter_same_seed:
seeds = iter_seed
if args.iter_same_seed:
seed = iter_seed
else:
seed = None # 前のを消す
if seed is None:
seed = random.randint(0, 0x7FFFFFFF)
seed = seed_random.randint(0, 2**32 - 1)
if args.interactive:
print(f"seed: {seed}")
@@ -2853,6 +2929,15 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
)
parser.add_argument(
"--from_module",
type=str,
default=None,
help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む",
)
parser.add_argument(
"--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数"
)
parser.add_argument(
"--interactive",
action="store_true",
@@ -3067,6 +3152,13 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
)
parser.add_argument(
"--emb_normalize_mode",
type=str,
default="original",
choices=["original", "none", "abs"],
help="embedding normalization mode / embeddingの正規化モード",
)
parser.add_argument(
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
)

View File

@@ -12,8 +12,10 @@ import numpy as np
import torch
import re
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -248,7 +250,8 @@ class LoRAInfModule(LoRAModule):
if mask is None:
# raise ValueError(f"mask is None for resolution {area}")
# emb_layers in SDXL doesn't have mask
# logger.info(f"mask is None for resolution {area}, {x.size()}")
# if "emb" not in self.lora_name:
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
if len(x.size()) != 4:
@@ -265,7 +268,9 @@ class LoRAInfModule(LoRAModule):
# apply mask for LoRA result
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
mask = self.get_mask_for_x(lx)
# logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}")
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
# mask = mask.squeeze(-1)
lx = lx * mask
x = self.org_forward(x)
@@ -514,7 +519,9 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else:
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks
if block_alphas is not None:
@@ -792,7 +799,9 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"create LoRA network from weights")
elif block_dims is not None:
logger.info(f"create LoRA network from block_dims")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None:
@@ -800,9 +809,13 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"conv_block_alphas: {conv_block_alphas}")
else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None:
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# create module instances
def create_modules(
@@ -929,6 +942,10 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.multiplier = self.multiplier
def set_enabled(self, is_enabled):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -1116,7 +1133,7 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.set_network(self)
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
self.batch_size = batch_size
self.num_sub_prompts = num_sub_prompts
self.current_size = (height, width)
@@ -1142,6 +1159,13 @@ class LoRANetwork(torch.nn.Module):
resize_add(h, w)
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
resize_add(h + h % 2, w + w % 2)
# deep shrink
if ds_ratio is not None:
hd = int(h * ds_ratio)
wd = int(w * ds_ratio)
resize_add(hd, wd)
h = (h + 1) // 2
w = (w + 1) // 2