mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
supprt deep shink with regional lora, add prompter module
This commit is contained in:
168
gen_img.py
168
gen_img.py
@@ -3,6 +3,8 @@ import json
|
|||||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
import inspect
|
import inspect
|
||||||
import time
|
import time
|
||||||
import zipfile
|
import zipfile
|
||||||
@@ -333,6 +335,10 @@ class PipelineLike:
|
|||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
|
|
||||||
|
self.clip_vision_model: CLIPVisionModelWithProjection = None
|
||||||
|
self.clip_vision_processor: CLIPImageProcessor = None
|
||||||
|
self.clip_vision_strength = 0.0
|
||||||
|
|
||||||
# Textual Inversion
|
# Textual Inversion
|
||||||
self.token_replacements_list = []
|
self.token_replacements_list = []
|
||||||
for _ in range(len(self.text_encoders)):
|
for _ in range(len(self.text_encoders)):
|
||||||
@@ -419,6 +425,7 @@ class PipelineLike:
|
|||||||
callback_steps: Optional[int] = 1,
|
callback_steps: Optional[int] = 1,
|
||||||
img2img_noise=None,
|
img2img_noise=None,
|
||||||
clip_guide_images=None,
|
clip_guide_images=None,
|
||||||
|
emb_normalize_mode: str = "original",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# TODO support secondary prompt
|
# TODO support secondary prompt
|
||||||
@@ -493,6 +500,7 @@ class PipelineLike:
|
|||||||
clip_skip=self.clip_skip,
|
clip_skip=self.clip_skip,
|
||||||
token_replacer=token_replacer,
|
token_replacer=token_replacer,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
emb_normalize_mode=emb_normalize_mode,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
tes_text_embs.append(text_embeddings)
|
tes_text_embs.append(text_embeddings)
|
||||||
@@ -508,6 +516,7 @@ class PipelineLike:
|
|||||||
clip_skip=self.clip_skip,
|
clip_skip=self.clip_skip,
|
||||||
token_replacer=token_replacer,
|
token_replacer=token_replacer,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
emb_normalize_mode=emb_normalize_mode,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
tes_real_uncond_embs.append(real_uncond_embeddings)
|
tes_real_uncond_embs.append(real_uncond_embeddings)
|
||||||
@@ -1099,7 +1108,7 @@ def get_unweighted_text_embeddings(
|
|||||||
# in sdxl, value of clip_skip is same for Text Encoder 1 and 2
|
# in sdxl, value of clip_skip is same for Text Encoder 1 and 2
|
||||||
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
||||||
text_embedding = enc_out["hidden_states"][-clip_skip]
|
text_embedding = enc_out["hidden_states"][-clip_skip]
|
||||||
if not is_sdxl: # SD 1.5 requires final_layer_norm
|
if not is_sdxl: # SD 1.5 requires final_layer_norm
|
||||||
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
||||||
if pool is None:
|
if pool is None:
|
||||||
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
|
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided
|
||||||
@@ -1122,7 +1131,7 @@ def get_unweighted_text_embeddings(
|
|||||||
else:
|
else:
|
||||||
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
||||||
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
||||||
if not is_sdxl: # SD 1.5 requires final_layer_norm
|
if not is_sdxl: # SD 1.5 requires final_layer_norm
|
||||||
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
||||||
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
|
pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this
|
||||||
if pool is not None:
|
if pool is not None:
|
||||||
@@ -1143,6 +1152,7 @@ def get_weighted_text_embeddings(
|
|||||||
clip_skip: int = 1,
|
clip_skip: int = 1,
|
||||||
token_replacer=None,
|
token_replacer=None,
|
||||||
device=None,
|
device=None,
|
||||||
|
emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none"
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||||
@@ -1239,16 +1249,34 @@ def get_weighted_text_embeddings(
|
|||||||
# assign weights to the prompts and normalize in the sense of mean
|
# assign weights to the prompts and normalize in the sense of mean
|
||||||
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
||||||
# →全体でいいんじゃないかな
|
# →全体でいいんじゃないかな
|
||||||
|
|
||||||
if (not skip_parsing) and (not skip_weighting):
|
if (not skip_parsing) and (not skip_weighting):
|
||||||
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
if emb_normalize_mode == "abs":
|
||||||
text_embeddings *= prompt_weights.unsqueeze(-1)
|
previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||||
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||||
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||||
if uncond_prompt is not None:
|
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||||
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
if uncond_prompt is not None:
|
||||||
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||||
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||||
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||||
|
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
|
elif emb_normalize_mode == "none":
|
||||||
|
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||||
|
|
||||||
|
else: # "original"
|
||||||
|
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||||
|
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||||
|
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||||
|
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
if uncond_prompt is not None:
|
||||||
|
previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||||
|
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||||
|
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||||
|
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||||
|
|
||||||
if uncond_prompt is not None:
|
if uncond_prompt is not None:
|
||||||
return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
|
return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens
|
||||||
@@ -1427,6 +1455,27 @@ class BatchData(NamedTuple):
|
|||||||
ext: BatchDataExt
|
ext: BatchDataExt
|
||||||
|
|
||||||
|
|
||||||
|
class ListPrompter:
|
||||||
|
def __init__(self, prompts: List[str]):
|
||||||
|
self.prompts = prompts
|
||||||
|
self.index = 0
|
||||||
|
|
||||||
|
def shuffle(self):
|
||||||
|
random.shuffle(self.prompts)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.prompts)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
if self.index >= len(self.prompts):
|
||||||
|
self.index = 0 # reset
|
||||||
|
return None
|
||||||
|
|
||||||
|
prompt = self.prompts[self.index]
|
||||||
|
self.index += 1
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
@@ -1951,15 +2000,35 @@ def main(args):
|
|||||||
token_embeds2[token_id] = embed
|
token_embeds2[token_id] = embed
|
||||||
|
|
||||||
# promptを取得する
|
# promptを取得する
|
||||||
|
prompt_list = None
|
||||||
if args.from_file is not None:
|
if args.from_file is not None:
|
||||||
print(f"reading prompts from {args.from_file}")
|
print(f"reading prompts from {args.from_file}")
|
||||||
with open(args.from_file, "r", encoding="utf-8") as f:
|
with open(args.from_file, "r", encoding="utf-8") as f:
|
||||||
prompt_list = f.read().splitlines()
|
prompt_list = f.read().splitlines()
|
||||||
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
|
prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"]
|
||||||
|
prompter = ListPrompter(prompt_list)
|
||||||
|
|
||||||
|
elif args.from_module is not None:
|
||||||
|
|
||||||
|
def load_module_from_path(module_name, file_path):
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
||||||
|
if spec is None:
|
||||||
|
raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'")
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[module_name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
print(f"reading prompts from module: {args.from_module}")
|
||||||
|
prompt_module = load_module_from_path("prompt_module", args.from_module)
|
||||||
|
|
||||||
|
prompter = prompt_module.get_prompter(args, pipe, networks)
|
||||||
|
|
||||||
elif args.prompt is not None:
|
elif args.prompt is not None:
|
||||||
prompt_list = [args.prompt]
|
prompter = ListPrompter([args.prompt])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
prompt_list = []
|
prompter = None # interactive mode
|
||||||
|
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
args.n_iter = 1
|
args.n_iter = 1
|
||||||
@@ -2026,14 +2095,16 @@ def main(args):
|
|||||||
mask_images = None
|
mask_images = None
|
||||||
|
|
||||||
# promptがないとき、画像のPngInfoから取得する
|
# promptがないとき、画像のPngInfoから取得する
|
||||||
if init_images is not None and len(prompt_list) == 0 and not args.interactive:
|
if init_images is not None and prompter is None and not args.interactive:
|
||||||
print("get prompts from images' metadata")
|
print("get prompts from images' metadata")
|
||||||
|
prompt_list = []
|
||||||
for img in init_images:
|
for img in init_images:
|
||||||
if "prompt" in img.text:
|
if "prompt" in img.text:
|
||||||
prompt = img.text["prompt"]
|
prompt = img.text["prompt"]
|
||||||
if "negative-prompt" in img.text:
|
if "negative-prompt" in img.text:
|
||||||
prompt += " --n " + img.text["negative-prompt"]
|
prompt += " --n " + img.text["negative-prompt"]
|
||||||
prompt_list.append(prompt)
|
prompt_list.append(prompt)
|
||||||
|
prompter = ListPrompter(prompt_list)
|
||||||
|
|
||||||
# プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する)
|
# プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する)
|
||||||
l = []
|
l = []
|
||||||
@@ -2105,15 +2176,18 @@ def main(args):
|
|||||||
else:
|
else:
|
||||||
guide_images = None
|
guide_images = None
|
||||||
|
|
||||||
# seed指定時はseedを決めておく
|
# 新しい乱数生成器を作成する
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
# dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう
|
if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1:
|
||||||
random.seed(args.seed)
|
# 引数のseedをそのまま使う
|
||||||
predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)]
|
def fixed_seed(*args, **kwargs):
|
||||||
if len(predefined_seeds) == 1:
|
return args.seed
|
||||||
predefined_seeds[0] = args.seed
|
|
||||||
|
seed_random = SimpleNamespace(randint=fixed_seed)
|
||||||
|
else:
|
||||||
|
seed_random = random.Random(args.seed)
|
||||||
else:
|
else:
|
||||||
predefined_seeds = None
|
seed_random = random.Random()
|
||||||
|
|
||||||
# デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み)
|
# デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み)
|
||||||
if args.W is None:
|
if args.W is None:
|
||||||
@@ -2127,11 +2201,14 @@ def main(args):
|
|||||||
|
|
||||||
for gen_iter in range(args.n_iter):
|
for gen_iter in range(args.n_iter):
|
||||||
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||||
iter_seed = random.randint(0, 0x7FFFFFFF)
|
if args.iter_same_seed:
|
||||||
|
iter_seed = seed_random.randint(0, 2**32 - 1)
|
||||||
|
else:
|
||||||
|
iter_seed = None
|
||||||
|
|
||||||
# shuffle prompt list
|
# shuffle prompt list
|
||||||
if args.shuffle_prompts:
|
if args.shuffle_prompts:
|
||||||
random.shuffle(prompt_list)
|
prompter.shuffle()
|
||||||
|
|
||||||
# バッチ処理の関数
|
# バッチ処理の関数
|
||||||
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
|
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
|
||||||
@@ -2352,7 +2429,8 @@ def main(args):
|
|||||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||||
n.set_multiplier(m)
|
n.set_multiplier(m)
|
||||||
if regional_network:
|
if regional_network:
|
||||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
# TODO バッチから ds_ratio を取り出すべき
|
||||||
|
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio)
|
||||||
|
|
||||||
if not regional_network and network_pre_calc:
|
if not regional_network and network_pre_calc:
|
||||||
for n in networks:
|
for n in networks:
|
||||||
@@ -2386,6 +2464,7 @@ def main(args):
|
|||||||
return_latents=return_latents,
|
return_latents=return_latents,
|
||||||
clip_prompts=clip_prompts,
|
clip_prompts=clip_prompts,
|
||||||
clip_guide_images=guide_images,
|
clip_guide_images=guide_images,
|
||||||
|
emb_normalize_mode=args.emb_normalize_mode,
|
||||||
)
|
)
|
||||||
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
if highres_1st and not args.highres_fix_save_1st: # return images or latents
|
||||||
return images
|
return images
|
||||||
@@ -2451,8 +2530,8 @@ def main(args):
|
|||||||
prompt_index = 0
|
prompt_index = 0
|
||||||
global_step = 0
|
global_step = 0
|
||||||
batch_data = []
|
batch_data = []
|
||||||
while args.interactive or prompt_index < len(prompt_list):
|
while True:
|
||||||
if len(prompt_list) == 0:
|
if args.interactive:
|
||||||
# interactive
|
# interactive
|
||||||
valid = False
|
valid = False
|
||||||
while not valid:
|
while not valid:
|
||||||
@@ -2466,7 +2545,9 @@ def main(args):
|
|||||||
if not valid: # EOF, end app
|
if not valid: # EOF, end app
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raw_prompt = prompt_list[prompt_index]
|
raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step)
|
||||||
|
if raw_prompt is None:
|
||||||
|
break
|
||||||
|
|
||||||
# sd-dynamic-prompts like variants:
|
# sd-dynamic-prompts like variants:
|
||||||
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
|
# count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration)
|
||||||
@@ -2513,7 +2594,8 @@ def main(args):
|
|||||||
|
|
||||||
prompt_args = raw_prompt.strip().split(" --")
|
prompt_args = raw_prompt.strip().split(" --")
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
|
length = len(prompter) if hasattr(prompter, "__len__") else 0
|
||||||
|
print(f"prompt {prompt_index+1}/{length}: {prompt}")
|
||||||
|
|
||||||
for parg in prompt_args[1:]:
|
for parg in prompt_args[1:]:
|
||||||
try:
|
try:
|
||||||
@@ -2731,23 +2813,17 @@ def main(args):
|
|||||||
|
|
||||||
# prepare seed
|
# prepare seed
|
||||||
if seeds is not None: # given in prompt
|
if seeds is not None: # given in prompt
|
||||||
# 数が足りないなら前のをそのまま使う
|
# num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う
|
||||||
if len(seeds) > 0:
|
if len(seeds) > 0:
|
||||||
seed = seeds.pop(0)
|
seed = seeds.pop(0)
|
||||||
else:
|
else:
|
||||||
if predefined_seeds is not None:
|
if args.iter_same_seed:
|
||||||
if len(predefined_seeds) > 0:
|
seed = iter_seed
|
||||||
seed = predefined_seeds.pop(0)
|
|
||||||
else:
|
|
||||||
print("predefined seeds are exhausted")
|
|
||||||
seed = None
|
|
||||||
elif args.iter_same_seed:
|
|
||||||
seeds = iter_seed
|
|
||||||
else:
|
else:
|
||||||
seed = None # 前のを消す
|
seed = None # 前のを消す
|
||||||
|
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = random.randint(0, 0x7FFFFFFF)
|
seed = seed_random.randint(0, 2**32 - 1)
|
||||||
if args.interactive:
|
if args.interactive:
|
||||||
print(f"seed: {seed}")
|
print(f"seed: {seed}")
|
||||||
|
|
||||||
@@ -2853,6 +2929,15 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=None,
|
default=None,
|
||||||
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
|
help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--from_module",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--interactive",
|
"--interactive",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -3067,6 +3152,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=None,
|
default=None,
|
||||||
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
|
help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--emb_normalize_mode",
|
||||||
|
type=str,
|
||||||
|
default="original",
|
||||||
|
choices=["original", "none", "abs"],
|
||||||
|
help="embedding normalization mode / embeddingの正規化モード",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
|
"--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,8 +12,10 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
@@ -248,7 +250,8 @@ class LoRAInfModule(LoRAModule):
|
|||||||
if mask is None:
|
if mask is None:
|
||||||
# raise ValueError(f"mask is None for resolution {area}")
|
# raise ValueError(f"mask is None for resolution {area}")
|
||||||
# emb_layers in SDXL doesn't have mask
|
# emb_layers in SDXL doesn't have mask
|
||||||
# logger.info(f"mask is None for resolution {area}, {x.size()}")
|
# if "emb" not in self.lora_name:
|
||||||
|
# print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}")
|
||||||
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1)
|
||||||
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts
|
||||||
if len(x.size()) != 4:
|
if len(x.size()) != 4:
|
||||||
@@ -265,7 +268,9 @@ class LoRAInfModule(LoRAModule):
|
|||||||
# apply mask for LoRA result
|
# apply mask for LoRA result
|
||||||
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
mask = self.get_mask_for_x(lx)
|
mask = self.get_mask_for_x(lx)
|
||||||
# logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}")
|
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||||
|
# if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked)
|
||||||
|
# mask = mask.squeeze(-1)
|
||||||
lx = lx * mask
|
lx = lx * mask
|
||||||
|
|
||||||
x = self.org_forward(x)
|
x = self.org_forward(x)
|
||||||
@@ -514,7 +519,9 @@ def get_block_dims_and_alphas(
|
|||||||
len(block_dims) == num_total_blocks
|
len(block_dims) == num_total_blocks
|
||||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||||
else:
|
else:
|
||||||
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
logger.warning(
|
||||||
|
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
||||||
|
)
|
||||||
block_dims = [network_dim] * num_total_blocks
|
block_dims = [network_dim] * num_total_blocks
|
||||||
|
|
||||||
if block_alphas is not None:
|
if block_alphas is not None:
|
||||||
@@ -792,7 +799,9 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
logger.info(f"create LoRA network from weights")
|
logger.info(f"create LoRA network from weights")
|
||||||
elif block_dims is not None:
|
elif block_dims is not None:
|
||||||
logger.info(f"create LoRA network from block_dims")
|
logger.info(f"create LoRA network from block_dims")
|
||||||
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
logger.info(
|
||||||
|
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||||
|
)
|
||||||
logger.info(f"block_dims: {block_dims}")
|
logger.info(f"block_dims: {block_dims}")
|
||||||
logger.info(f"block_alphas: {block_alphas}")
|
logger.info(f"block_alphas: {block_alphas}")
|
||||||
if conv_block_dims is not None:
|
if conv_block_dims is not None:
|
||||||
@@ -800,9 +809,13 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||||
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
logger.info(
|
||||||
|
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
||||||
|
)
|
||||||
if self.conv_lora_dim is not None:
|
if self.conv_lora_dim is not None:
|
||||||
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
logger.info(
|
||||||
|
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
||||||
|
)
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(
|
def create_modules(
|
||||||
@@ -929,6 +942,10 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
lora.multiplier = self.multiplier
|
lora.multiplier = self.multiplier
|
||||||
|
|
||||||
|
def set_enabled(self, is_enabled):
|
||||||
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
lora.enabled = is_enabled
|
||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
@@ -1116,7 +1133,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
lora.set_network(self)
|
lora.set_network(self)
|
||||||
|
|
||||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None):
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.num_sub_prompts = num_sub_prompts
|
self.num_sub_prompts = num_sub_prompts
|
||||||
self.current_size = (height, width)
|
self.current_size = (height, width)
|
||||||
@@ -1142,6 +1159,13 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
resize_add(h, w)
|
resize_add(h, w)
|
||||||
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
||||||
resize_add(h + h % 2, w + w % 2)
|
resize_add(h + h % 2, w + w % 2)
|
||||||
|
|
||||||
|
# deep shrink
|
||||||
|
if ds_ratio is not None:
|
||||||
|
hd = int(h * ds_ratio)
|
||||||
|
wd = int(w * ds_ratio)
|
||||||
|
resize_add(hd, wd)
|
||||||
|
|
||||||
h = (h + 1) // 2
|
h = (h + 1) // 2
|
||||||
w = (w + 1) // 2
|
w = (w + 1) // 2
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user