From d94c0d70fe05626fe1d3d8be0628ed89a16365ec Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 19 Feb 2023 18:43:35 +0900
Subject: [PATCH] support network mul from prompt

---
 gen_img_diffusers.py | 91 ++++++++++++++++++++++++++++++++++----------
 networks/lora.py     |  5 +++
 2 files changed, 75 insertions(+), 21 deletions(-)

diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 25a5b2d9..be92e99e 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -47,7 +47,7 @@ VGG(
 """

 import json
-from typing import List, Optional, Union
+from typing import Any, Callable, List, NamedTuple, Optional, Tuple, Union
 import glob
 import importlib
 import inspect
@@ -60,7 +60,6 @@ import math
 import os
 import random
 import re
-from typing import Any, Callable, List, Optional, Union

 import diffusers
 import numpy as np
@@ -1817,6 +1816,34 @@ def preprocess_mask(mask):
 #   return text_encoder


+class BatchDataBase(NamedTuple):
+  # per-image data; items may differ within a batch
+  step: int
+  prompt: str
+  negative_prompt: str
+  seed: int
+  init_image: Any
+  mask_image: Any
+  clip_prompt: str
+  guide_image: Any
+
+
+class BatchDataExt(NamedTuple):
+  # data that forces a batch split: all items in a batch must share these values
+  width: int
+  height: int
+  steps: int
+  scale: float
+  negative_scale: float
+  strength: float
+  network_muls: Tuple[float]
+
+
+class BatchData(NamedTuple):
+  base: BatchDataBase
+  ext: BatchDataExt
+
+
 def main(args):
   if args.fp16:
     dtype = torch.float16
@@ -1995,11 +2022,13 @@ def main(args):
   # incorporate the additional networks
   if args.network_module:
     networks = []
+    network_default_muls = []
     for i, network_module in enumerate(args.network_module):
       print("import network module:", network_module)
       imported_module = importlib.import_module(network_module)

       network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
+      network_default_muls.append(network_mul)

       net_kwargs = {}
       if args.network_args and i < len(args.network_args):
@@ -2014,7 +2043,7 @@ def main(args):
         network_weight = args.network_weights[i]
         print("load network weights from:", network_weight)

-        if model_util.is_safetensors(network_weight):
+        if model_util.is_safetensors(network_weight) and args.network_show_meta:
           from safetensors.torch import safe_open
           with safe_open(network_weight, framework="pt") as f:
             metadata = f.metadata()
@@ -2219,33 +2248,37 @@ def main(args):
     iter_seed = random.randint(0, 0x7fffffff)

   # batch processing function
-  def process_batch(batch, highres_fix, highres_1st=False):
+  def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
     batch_size = len(batch)

     # highres_fix handling
     if highres_fix and not highres_1st:
-      # create and run the 1st stage batch
+      # create and run the 1st stage batch: call with a reduced image size
       print("process 1st stage1")
       batch_1st = []
-      for params1, (width, height, steps, scale, negative_scale, strength) in batch:
-        width_1st = int(width * args.highres_fix_scale + .5)
-        height_1st = int(height * args.highres_fix_scale + .5)
+      for base, ext in batch:
+        width_1st = int(ext.width * args.highres_fix_scale + .5)
+        height_1st = int(ext.height * args.highres_fix_scale + .5)
         width_1st = width_1st - width_1st % 32
         height_1st = height_1st - height_1st % 32
-        batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
+
+        bd_1st = BatchData(base, BatchDataExt(width_1st, height_1st, args.highres_fix_steps,
+                                              ext.scale, ext.negative_scale, ext.strength, ext.network_muls))
+        batch_1st.append(bd_1st)
       images_1st = process_batch(batch_1st, True, True)

       # create the 2nd stage batch and process it below
       print("process 2nd stage1")
       batch_2nd = []
-      for i, (b1, image) in enumerate(zip(batch, images_1st)):
-        image = image.resize((width, height), resample=PIL.Image.LANCZOS)
-        (step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
-        batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
+      for i, (bd, image) in enumerate(zip(batch, images_1st)):
+        image = image.resize((bd.ext.width, bd.ext.height), resample=PIL.Image.LANCZOS)  # use as img2img input
+        bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:8]), bd.ext)
+        batch_2nd.append(bd_2nd)
       batch = batch_2nd

-    (step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
-     height, steps, scale, negative_scale, strength) = batch[0]
+    # take out this batch's shared parameters
+    (step_first, _, _, _, init_image, mask_image, _, guide_image), \
+      (width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]

     noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
     prompts = []
@@ -2321,6 +2354,10 @@ def main(args):
       guide_images = guide_images[0]

     # generate
+    if networks:
+      for n, m in zip(networks, network_muls if network_muls else network_default_muls):
+        n.set_multiplier(m)
+
     images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength,
                   latents=start_code, output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
     if highres_1st and not args.highres_fix_save_1st:
@@ -2398,6 +2435,7 @@ def main(args):
       strength = 0.8 if args.strength is None else args.strength
       negative_prompt = ""
       clip_prompt = None
+      network_muls = None

       prompt_args = prompt.strip().split(' --')
       prompt = prompt_args[0]
@@ -2461,6 +2499,15 @@ def main(args):
              clip_prompt = m.group(1)
              print(f"clip prompt: {clip_prompt}")
              continue
+
+            m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
+            if m:  # network multipliers
+              network_muls = [float(v) for v in m.group(1).split(",")]
+              while len(network_muls) < len(networks):
+                network_muls.append(network_muls[-1])
+              print(f"network mul: {network_muls}")
+              continue
+
          except ValueError as ex:
            print(f"Exception in parsing / 解析エラー: {parg}")
            print(ex)
@@ -2506,9 +2553,8 @@ def main(args):
          print("Use previous image as guide image.")
          guide_image = prev_image

-      # TODO use a named tuple or something
-      b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
-            (width, height, steps, scale, negative_scale, strength))
+      b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
+                     BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
       if len(batch_data) > 0 and batch_data[-1][1] != b1[1]:  # need to split the batch?
        process_batch(batch_data, highres_fix)
        batch_data.clear()
@@ -2578,12 +2624,15 @@ if __name__ == '__main__':
   parser.add_argument("--opt_channels_last", action='store_true',
                       help='set channels last option to model / モデルにchannels lastを指定し最適化する')
   parser.add_argument("--network_module", type=str, default=None, nargs='*',
-                      help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
+                      help='additional network module to use / 追加ネットワークを使う時そのモジュール名')
   parser.add_argument("--network_weights", type=str, default=None, nargs='*',
-                      help='Hypernetwork weights to load / Hypernetworkの重み')
-  parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
+                      help='additional network weights to load / 追加ネットワークの重み')
+  parser.add_argument("--network_mul", type=float, default=None, nargs='*',
+                      help='additional network multiplier / 追加ネットワークの効果の倍率')
   parser.add_argument("--network_args", type=str, default=None, nargs='*',
                       help='additional arguments for network (key=value) / ネットワークへの追加の引数')
+  parser.add_argument("--network_show_meta", action='store_true',
+                      help='show metadata of network model / ネットワークモデルのメタデータを表示する')
   parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
                       help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
   parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
diff --git a/networks/lora.py b/networks/lora.py
index a1f38c16..24b107ba 100644
--- a/networks/lora.py
+++ b/networks/lora.py
@@ -126,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
       assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
       names.add(lora.lora_name)

+  def set_multiplier(self, multiplier):
+    self.multiplier = multiplier
+    for lora in self.text_encoder_loras + self.unet_loras:
+      lora.multiplier = self.multiplier
+
   def load_weights(self, file):
     if os.path.splitext(file)[1] == '.safetensors':
       from safetensors.torch import load_file, safe_open
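
Usage note (editor's annotation, not part of the commit): with this patch, a prompt
line can override the command-line --network_mul values per image via the new
"--am" option, e.g. with two networks loaded, a prompt line such as

  1girl, masterpiece --n lowres --am 0.5,0.8

sets the first network's multiplier to 0.5 and the second's to 0.8 for that prompt
only. When fewer values than networks are given, the last value is repeated, and
process_batch applies the values through LoRANetwork.set_multiplier just before
calling the pipeline.

The following is a minimal, self-contained sketch (the ToyLoRA class is hypothetical,
not code from this repository) of why set_multiplier takes effect without reloading
weights: the real LoRAModule in networks/lora.py returns
org_forward(x) + lora_up(lora_down(x)) * multiplier * scale, so the multiplier is
read again on every forward pass.

import torch


class ToyLoRA(torch.nn.Module):
  # hypothetical stand-in for LoRAModule; mirrors only the multiplier behaviour
  def __init__(self, org_module: torch.nn.Linear, rank: int = 4, multiplier: float = 1.0):
    super().__init__()
    self.org_module = org_module
    self.lora_down = torch.nn.Linear(org_module.in_features, rank, bias=False)
    self.lora_up = torch.nn.Linear(rank, org_module.out_features, bias=False)
    self.multiplier = multiplier

  def forward(self, x):
    # the multiplier scales the low-rank delta at call time, so assigning a new
    # value between prompts (all that set_multiplier does) changes the next
    # generation without touching any weights
    return self.org_module(x) + self.lora_up(self.lora_down(x)) * self.multiplier


if __name__ == "__main__":
  lora = ToyLoRA(torch.nn.Linear(8, 8))
  x = torch.randn(1, 8)
  lora.multiplier = 0.0                   # disables the LoRA delta entirely
  out_off = lora(x)
  lora.multiplier = 1.0                   # full strength
  out_on = lora(x)
  print(torch.allclose(out_off, out_on))  # False: the delta is applied now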