From cbe9c5dc068cd81c5f7d53c7aeba601d5241e7f9 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 12 Feb 2024 14:17:27 +0900 Subject: [PATCH] supprt deep shink with regional lora, add prompter module --- gen_img.py | 168 ++++++++++++++++++++++++++++++++++++----------- networks/lora.py | 38 +++++++++-- 2 files changed, 161 insertions(+), 45 deletions(-) diff --git a/gen_img.py b/gen_img.py index f5a740c3..a24220a0 100644 --- a/gen_img.py +++ b/gen_img.py @@ -3,6 +3,8 @@ import json from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable import glob import importlib +import importlib.util +import sys import inspect import time import zipfile @@ -333,6 +335,10 @@ class PipelineLike: self.scheduler = scheduler self.safety_checker = None + self.clip_vision_model: CLIPVisionModelWithProjection = None + self.clip_vision_processor: CLIPImageProcessor = None + self.clip_vision_strength = 0.0 + # Textual Inversion self.token_replacements_list = [] for _ in range(len(self.text_encoders)): @@ -419,6 +425,7 @@ class PipelineLike: callback_steps: Optional[int] = 1, img2img_noise=None, clip_guide_images=None, + emb_normalize_mode: str = "original", **kwargs, ): # TODO support secondary prompt @@ -493,6 +500,7 @@ class PipelineLike: clip_skip=self.clip_skip, token_replacer=token_replacer, device=self.device, + emb_normalize_mode=emb_normalize_mode, **kwargs, ) tes_text_embs.append(text_embeddings) @@ -508,6 +516,7 @@ class PipelineLike: clip_skip=self.clip_skip, token_replacer=token_replacer, device=self.device, + emb_normalize_mode=emb_normalize_mode, **kwargs, ) tes_real_uncond_embs.append(real_uncond_embeddings) @@ -1099,7 +1108,7 @@ def get_unweighted_text_embeddings( # in sdxl, value of clip_skip is same for Text Encoder 1 and 2 enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True) text_embedding = enc_out["hidden_states"][-clip_skip] - if not is_sdxl: # SD 1.5 requires final_layer_norm + if not is_sdxl: # SD 1.5 requires final_layer_norm text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) if pool is None: pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided @@ -1122,7 +1131,7 @@ def get_unweighted_text_embeddings( else: enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) text_embeddings = enc_out["hidden_states"][-clip_skip] - if not is_sdxl: # SD 1.5 requires final_layer_norm + if not is_sdxl: # SD 1.5 requires final_layer_norm text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) pool = enc_out.get("text_embeds", None) # text encoder 1 doesn't return this if pool is not None: @@ -1143,6 +1152,7 @@ def get_weighted_text_embeddings( clip_skip: int = 1, token_replacer=None, device=None, + emb_normalize_mode: Optional[str] = "original", # "original", "abs", "none" **kwargs, ): max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 @@ -1239,16 +1249,34 @@ def get_weighted_text_embeddings( # assign weights to the prompts and normalize in the sense of mean # TODO: should we normalize by chunk or in a whole (current implementation)? # →全体でいいんじゃないかな + if (not skip_parsing) and (not skip_weighting): - previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) - text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - if uncond_prompt is not None: - previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) - uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if emb_normalize_mode == "abs": + previous_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().abs().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().abs().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + elif emb_normalize_mode == "none": + text_embeddings *= prompt_weights.unsqueeze(-1) + if uncond_prompt is not None: + uncond_embeddings *= uncond_weights.unsqueeze(-1) + + else: # "original" + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) if uncond_prompt is not None: return text_embeddings, text_pool, uncond_embeddings, uncond_pool, prompt_tokens @@ -1427,6 +1455,27 @@ class BatchData(NamedTuple): ext: BatchDataExt +class ListPrompter: + def __init__(self, prompts: List[str]): + self.prompts = prompts + self.index = 0 + + def shuffle(self): + random.shuffle(self.prompts) + + def __len__(self): + return len(self.prompts) + + def __call__(self, *args, **kwargs): + if self.index >= len(self.prompts): + self.index = 0 # reset + return None + + prompt = self.prompts[self.index] + self.index += 1 + return prompt + + def main(args): if args.fp16: dtype = torch.float16 @@ -1951,15 +2000,35 @@ def main(args): token_embeds2[token_id] = embed # promptを取得する + prompt_list = None if args.from_file is not None: print(f"reading prompts from {args.from_file}") with open(args.from_file, "r", encoding="utf-8") as f: prompt_list = f.read().splitlines() prompt_list = [d for d in prompt_list if len(d.strip()) > 0 and d[0] != "#"] + prompter = ListPrompter(prompt_list) + + elif args.from_module is not None: + + def load_module_from_path(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Module '{module_name}' cannot be loaded from '{file_path}'") + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + print(f"reading prompts from module: {args.from_module}") + prompt_module = load_module_from_path("prompt_module", args.from_module) + + prompter = prompt_module.get_prompter(args, pipe, networks) + elif args.prompt is not None: - prompt_list = [args.prompt] + prompter = ListPrompter([args.prompt]) + else: - prompt_list = [] + prompter = None # interactive mode if args.interactive: args.n_iter = 1 @@ -2026,14 +2095,16 @@ def main(args): mask_images = None # promptがないとき、画像のPngInfoから取得する - if init_images is not None and len(prompt_list) == 0 and not args.interactive: + if init_images is not None and prompter is None and not args.interactive: print("get prompts from images' metadata") + prompt_list = [] for img in init_images: if "prompt" in img.text: prompt = img.text["prompt"] if "negative-prompt" in img.text: prompt += " --n " + img.text["negative-prompt"] prompt_list.append(prompt) + prompter = ListPrompter(prompt_list) # プロンプトと画像を一致させるため指定回数だけ繰り返す(画像を増幅する) l = [] @@ -2105,15 +2176,18 @@ def main(args): else: guide_images = None - # seed指定時はseedを決めておく + # 新しい乱数生成器を作成する if args.seed is not None: - # dynamic promptを使うと足りなくなる→images_per_promptを適当に大きくしておいてもらう - random.seed(args.seed) - predefined_seeds = [random.randint(0, 0x7FFFFFFF) for _ in range(args.n_iter * len(prompt_list) * args.images_per_prompt)] - if len(predefined_seeds) == 1: - predefined_seeds[0] = args.seed + if prompt_list and len(prompt_list) == 1 and args.images_per_prompt == 1: + # 引数のseedをそのまま使う + def fixed_seed(*args, **kwargs): + return args.seed + + seed_random = SimpleNamespace(randint=fixed_seed) + else: + seed_random = random.Random(args.seed) else: - predefined_seeds = None + seed_random = random.Random() # デフォルト画像サイズを設定する:img2imgではこれらの値は無視される(またはW*Hにリサイズ済み) if args.W is None: @@ -2127,11 +2201,14 @@ def main(args): for gen_iter in range(args.n_iter): print(f"iteration {gen_iter+1}/{args.n_iter}") - iter_seed = random.randint(0, 0x7FFFFFFF) + if args.iter_same_seed: + iter_seed = seed_random.randint(0, 2**32 - 1) + else: + iter_seed = None # shuffle prompt list if args.shuffle_prompts: - random.shuffle(prompt_list) + prompter.shuffle() # バッチ処理の関数 def process_batch(batch: List[BatchData], highres_fix, highres_1st=False): @@ -2352,7 +2429,8 @@ def main(args): for n, m in zip(networks, network_muls if network_muls else network_default_muls): n.set_multiplier(m) if regional_network: - n.set_current_generation(batch_size, num_sub_prompts, width, height, shared) + # TODO バッチから ds_ratio を取り出すべき + n.set_current_generation(batch_size, num_sub_prompts, width, height, shared, unet.ds_ratio) if not regional_network and network_pre_calc: for n in networks: @@ -2386,6 +2464,7 @@ def main(args): return_latents=return_latents, clip_prompts=clip_prompts, clip_guide_images=guide_images, + emb_normalize_mode=args.emb_normalize_mode, ) if highres_1st and not args.highres_fix_save_1st: # return images or latents return images @@ -2451,8 +2530,8 @@ def main(args): prompt_index = 0 global_step = 0 batch_data = [] - while args.interactive or prompt_index < len(prompt_list): - if len(prompt_list) == 0: + while True: + if args.interactive: # interactive valid = False while not valid: @@ -2466,7 +2545,9 @@ def main(args): if not valid: # EOF, end app break else: - raw_prompt = prompt_list[prompt_index] + raw_prompt = prompter(args, pipe, seed_random, iter_seed, prompt_index, global_step) + if raw_prompt is None: + break # sd-dynamic-prompts like variants: # count is 1 (not dynamic) or images_per_prompt (no enumeration) or arbitrary (enumeration) @@ -2513,7 +2594,8 @@ def main(args): prompt_args = raw_prompt.strip().split(" --") prompt = prompt_args[0] - print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}") + length = len(prompter) if hasattr(prompter, "__len__") else 0 + print(f"prompt {prompt_index+1}/{length}: {prompt}") for parg in prompt_args[1:]: try: @@ -2731,23 +2813,17 @@ def main(args): # prepare seed if seeds is not None: # given in prompt - # 数が足りないなら前のをそのまま使う + # num_images_per_promptが多い場合は足りなくなるので、足りない分は前のを使う if len(seeds) > 0: seed = seeds.pop(0) else: - if predefined_seeds is not None: - if len(predefined_seeds) > 0: - seed = predefined_seeds.pop(0) - else: - print("predefined seeds are exhausted") - seed = None - elif args.iter_same_seed: - seeds = iter_seed + if args.iter_same_seed: + seed = iter_seed else: seed = None # 前のを消す if seed is None: - seed = random.randint(0, 0x7FFFFFFF) + seed = seed_random.randint(0, 2**32 - 1) if args.interactive: print(f"seed: {seed}") @@ -2853,6 +2929,15 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="if specified, load prompts from this file / 指定時はプロンプトをファイルから読み込む", ) + parser.add_argument( + "--from_module", + type=str, + default=None, + help="if specified, load prompts from this module / 指定時はプロンプトをモジュールから読み込む", + ) + parser.add_argument( + "--prompter_module_args", type=str, default=None, help="args for prompter module / prompterモジュールの引数" + ) parser.add_argument( "--interactive", action="store_true", @@ -3067,6 +3152,13 @@ def setup_parser() -> argparse.ArgumentParser: default=None, help="max embedding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる", ) + parser.add_argument( + "--emb_normalize_mode", + type=str, + default="original", + choices=["original", "none", "abs"], + help="embedding normalization mode / embeddingの正規化モード", + ) parser.add_argument( "--guide_image_path", type=str, default=None, nargs="*", help="image to ControlNet / ControlNetでガイドに使う画像" ) diff --git a/networks/lora.py b/networks/lora.py index eaf656ac..948b30b0 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -12,8 +12,10 @@ import numpy as np import torch import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -248,7 +250,8 @@ class LoRAInfModule(LoRAModule): if mask is None: # raise ValueError(f"mask is None for resolution {area}") # emb_layers in SDXL doesn't have mask - # logger.info(f"mask is None for resolution {area}, {x.size()}") + # if "emb" not in self.lora_name: + # print(f"mask is None for resolution {self.lora_name}, {area}, {x.size()}") mask_size = (1, x.size()[1]) if len(x.size()) == 2 else (1, *x.size()[1:-1], 1) return torch.ones(mask_size, dtype=x.dtype, device=x.device) / self.network.num_sub_prompts if len(x.size()) != 4: @@ -265,7 +268,9 @@ class LoRAInfModule(LoRAModule): # apply mask for LoRA result lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale mask = self.get_mask_for_x(lx) - # logger.info(f"regional {self.lora_name} {self.network.sub_prompt_index} {lx.size()} {mask.size()}") + # print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size()) + # if mask.ndim > lx.ndim: # in some resolution, lx is 2d and mask is 3d (the reason is not checked) + # mask = mask.squeeze(-1) lx = lx * mask x = self.org_forward(x) @@ -514,7 +519,9 @@ def get_block_dims_and_alphas( len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning( + f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -792,7 +799,9 @@ class LoRANetwork(torch.nn.Module): logger.info(f"create LoRA network from weights") elif block_dims is not None: logger.info(f"create LoRA network from block_dims") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) logger.info(f"block_dims: {block_dims}") logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: @@ -800,9 +809,13 @@ class LoRANetwork(torch.nn.Module): logger.info(f"conv_block_alphas: {conv_block_alphas}") else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -929,6 +942,10 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.multiplier = self.multiplier + def set_enabled(self, is_enabled): + for lora in self.text_encoder_loras + self.unet_loras: + lora.enabled = is_enabled + def load_weights(self, file): if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file @@ -1116,7 +1133,7 @@ class LoRANetwork(torch.nn.Module): for lora in self.text_encoder_loras + self.unet_loras: lora.set_network(self) - def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared): + def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared, ds_ratio=None): self.batch_size = batch_size self.num_sub_prompts = num_sub_prompts self.current_size = (height, width) @@ -1142,6 +1159,13 @@ class LoRANetwork(torch.nn.Module): resize_add(h, w) if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2 resize_add(h + h % 2, w + w % 2) + + # deep shrink + if ds_ratio is not None: + hd = int(h * ds_ratio) + wd = int(w * ds_ratio) + resize_add(hd, wd) + h = (h + 1) // 2 w = (w + 1) // 2