support network mul from prompt

This commit is contained in:
Kohya S
2023-02-19 18:43:35 +09:00
parent e45e272e9d
commit d94c0d70fe
2 changed files with 73 additions and 19 deletions

View File

@@ -47,7 +47,7 @@ VGG(
""" """
import json import json
from typing import List, Optional, Union from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
import glob import glob
import importlib import importlib
import inspect import inspect
@@ -60,7 +60,6 @@ import math
import os import os
import random import random
import re import re
from typing import Any, Callable, List, Optional, Union
import diffusers import diffusers
import numpy as np import numpy as np
@@ -1817,6 +1816,34 @@ def preprocess_mask(mask):
# return text_encoder # return text_encoder
class BatchDataBase(NamedTuple):
    """Per-image generation data that does not force a batch split.

    Field order is significant: callers unpack instances positionally and
    slice them (e.g. ``base[0:3]`` and ``base[6:8]``), so new fields must not
    be inserted between existing ones.
    """

    # Data for which batch splitting is not required.
    step: int  # global step counter value when the entry was queued
    prompt: str
    negative_prompt: str
    seed: int
    init_image: Any  # replaced by the 1st-stage output in the highres-fix 2nd stage
    mask_image: Any  # set to None for the highres-fix 2nd stage
    clip_prompt: Optional[str]  # None when the prompt line supplies no clip prompt
    guide_image: Any
class BatchDataExt(NamedTuple):
    """Generation settings that must match for entries to share one batch.

    Instances are compared with ``!=`` to decide whether a new batch has to be
    started, so every field must support value equality (hence a tuple, not a
    list, for ``network_muls``).
    """

    # Data for which batch splitting is required.
    width: int
    height: int
    steps: int
    scale: float
    negative_scale: float
    strength: float
    # One multiplier per loaded network, or None to use the default
    # multipliers collected when the networks were loaded.
    network_muls: Optional[Tuple[float, ...]]
class BatchData(NamedTuple):
    """One queued generation request, as a ``(base, ext)`` pair.

    ``base`` holds the data that may differ between entries of one batch;
    ``ext`` holds the settings that have to be equal across a batch.
    """

    base: BatchDataBase
    ext: BatchDataExt
def main(args): def main(args):
if args.fp16: if args.fp16:
dtype = torch.float16 dtype = torch.float16
@@ -1995,11 +2022,13 @@ def main(args):
# networkを組み込む # networkを組み込む
if args.network_module: if args.network_module:
networks = [] networks = []
network_default_muls = []
for i, network_module in enumerate(args.network_module): for i, network_module in enumerate(args.network_module):
print("import network module:", network_module) print("import network module:", network_module)
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_default_muls.append(network_mul)
net_kwargs = {} net_kwargs = {}
if args.network_args and i < len(args.network_args): if args.network_args and i < len(args.network_args):
@@ -2014,7 +2043,7 @@ def main(args):
network_weight = args.network_weights[i] network_weight = args.network_weights[i]
print("load network weights from:", network_weight) print("load network weights from:", network_weight)
if model_util.is_safetensors(network_weight): if model_util.is_safetensors(network_weight) and args.network_show_meta:
from safetensors.torch import safe_open from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f: with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
@@ -2219,33 +2248,37 @@ def main(args):
iter_seed = random.randint(0, 0x7fffffff) iter_seed = random.randint(0, 0x7fffffff)
# バッチ処理の関数 # バッチ処理の関数
def process_batch(batch, highres_fix, highres_1st=False): def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
batch_size = len(batch) batch_size = len(batch)
# highres_fixの処理 # highres_fixの処理
if highres_fix and not highres_1st: if highres_fix and not highres_1st:
# 1st stageのバッチを作成して呼び出す # 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
print("process 1st stage1") print("process 1st stage1")
batch_1st = [] batch_1st = []
for params1, (width, height, steps, scale, negative_scale, strength) in batch: for base, ext in batch:
width_1st = int(width * args.highres_fix_scale + .5) width_1st = int(width * args.highres_fix_scale + .5)
height_1st = int(height * args.highres_fix_scale + .5) height_1st = int(height * args.highres_fix_scale + .5)
width_1st = width_1st - width_1st % 32 width_1st = width_1st - width_1st % 32
height_1st = height_1st - height_1st % 32 height_1st = height_1st - height_1st % 32
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
bd_1st = BatchData(base, BatchDataExt(width_1st, height_1st, args.highres_fix_steps,
ext.scale, ext.negative_scale, ext.strength, ext.network_muls))
batch_1st.append(bd_1st)
images_1st = process_batch(batch_1st, True, True) images_1st = process_batch(batch_1st, True, True)
# 2nd stageのバッチを作成して以下処理する # 2nd stageのバッチを作成して以下処理する
print("process 2nd stage1") print("process 2nd stage1")
batch_2nd = [] batch_2nd = []
for i, (b1, image) in enumerate(zip(batch, images_1st)): for i, (bd, image) in enumerate(zip(batch, images_1st)):
image = image.resize((width, height), resample=PIL.Image.LANCZOS) image = image.resize((width, height), resample=PIL.Image.LANCZOS) # img2imgとして設定
(step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1 bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:8]), bd.ext)
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2)) batch_2nd.append(bd_2nd)
batch = batch_2nd batch = batch_2nd
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width, # このバッチの情報を取り出す
height, steps, scale, negative_scale, strength) = batch[0] (step_first, _, _, _, init_image, mask_image, _, guide_image), \
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR) noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
prompts = [] prompts = []
@@ -2321,6 +2354,10 @@ def main(args):
guide_images = guide_images[0] guide_images = guide_images[0]
# generate # generate
if networks:
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
n.set_multiplier(m)
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code, images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0] output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
if highres_1st and not args.highres_fix_save_1st: if highres_1st and not args.highres_fix_save_1st:
@@ -2398,6 +2435,7 @@ def main(args):
strength = 0.8 if args.strength is None else args.strength strength = 0.8 if args.strength is None else args.strength
negative_prompt = "" negative_prompt = ""
clip_prompt = None clip_prompt = None
network_muls = None
prompt_args = prompt.strip().split(' --') prompt_args = prompt.strip().split(' --')
prompt = prompt_args[0] prompt = prompt_args[0]
@@ -2461,6 +2499,15 @@ def main(args):
clip_prompt = m.group(1) clip_prompt = m.group(1)
print(f"clip prompt: {clip_prompt}") print(f"clip prompt: {clip_prompt}")
continue continue
m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
if m: # network multiplies
network_muls = [float(v) for v in m.group(1).split(",")]
while len(network_muls) < len(networks):
network_muls.append(network_muls[-1])
print(f"network mul: {network_muls}")
continue
except ValueError as ex: except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}") print(f"Exception in parsing / 解析エラー: {parg}")
print(ex) print(ex)
@@ -2506,9 +2553,8 @@ def main(args):
print("Use previous image as guide image.") print("Use previous image as guide image.")
guide_image = prev_image guide_image = prev_image
# TODO named tupleか何かにする b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image), BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
(width, height, steps, scale, negative_scale, strength))
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要? if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
process_batch(batch_data, highres_fix) process_batch(batch_data, highres_fix)
batch_data.clear() batch_data.clear()
@@ -2578,12 +2624,15 @@ if __name__ == '__main__':
parser.add_argument("--opt_channels_last", action='store_true', parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannels lastを指定し最適化する') help='set channels last option to model / モデルにchannels lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, nargs='*', parser.add_argument("--network_module", type=str, default=None, nargs='*',
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') help='additiona network module to use / 追加ネットワークを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, nargs='*', parser.add_argument("--network_weights", type=str, default=None, nargs='*',
help='Hypernetwork weights to load / Hypernetworkの重み') help='additiona network weights to load / 追加ネットワークの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') parser.add_argument("--network_mul", type=float, default=None, nargs='*',
help='additiona network multiplier / 追加ネットワークの効果の倍率')
parser.add_argument("--network_args", type=str, default=None, nargs='*', parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数') help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--network_show_meta", action='store_true',
help='show metadata of network model / ネットワークモデルのメタデータを表示する')
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*', parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings') help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')

View File

@@ -126,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name) names.add(lora.lora_name)
def set_multiplier(self, multiplier):
    """Record a new multiplier on the network and on each of its LoRA modules.

    Args:
        multiplier: value stored as ``self.multiplier`` and then assigned to
            the ``multiplier`` attribute of every entry in
            ``self.text_encoder_loras`` and ``self.unet_loras``.
    """
    self.multiplier = multiplier
    # Text-encoder modules first, then U-Net modules, as one combined list.
    targets = self.text_encoder_loras + self.unet_loras
    for module in targets:
        module.multiplier = self.multiplier
def load_weights(self, file): def load_weights(self, file):
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file, safe_open from safetensors.torch import load_file, safe_open