mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support network mul from prompt
This commit is contained in:
@@ -47,7 +47,7 @@ VGG(
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||
import glob
|
||||
import importlib
|
||||
import inspect
|
||||
@@ -60,7 +60,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import diffusers
|
||||
import numpy as np
|
||||
@@ -1817,6 +1816,34 @@ def preprocess_mask(mask):
|
||||
# return text_encoder
|
||||
|
||||
|
||||
class BatchDataBase(NamedTuple):
|
||||
# バッチ分割が必要ないデータ
|
||||
step: int
|
||||
prompt: str
|
||||
negative_prompt: str
|
||||
seed: int
|
||||
init_image: Any
|
||||
mask_image: Any
|
||||
clip_prompt: str
|
||||
guide_image: Any
|
||||
|
||||
|
||||
class BatchDataExt(NamedTuple):
|
||||
# バッチ分割が必要なデータ
|
||||
width: int
|
||||
height: int
|
||||
steps: int
|
||||
scale: float
|
||||
negative_scale: float
|
||||
strength: float
|
||||
network_muls: Tuple[float]
|
||||
|
||||
|
||||
class BatchData(NamedTuple):
|
||||
base: BatchDataBase
|
||||
ext: BatchDataExt
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.fp16:
|
||||
dtype = torch.float16
|
||||
@@ -1995,11 +2022,13 @@ def main(args):
|
||||
# networkを組み込む
|
||||
if args.network_module:
|
||||
networks = []
|
||||
network_default_muls = []
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_default_muls.append(network_mul)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
@@ -2014,7 +2043,7 @@ def main(args):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
|
||||
if model_util.is_safetensors(network_weight):
|
||||
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||
from safetensors.torch import safe_open
|
||||
with safe_open(network_weight, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
@@ -2219,33 +2248,37 @@ def main(args):
|
||||
iter_seed = random.randint(0, 0x7fffffff)
|
||||
|
||||
# バッチ処理の関数
|
||||
def process_batch(batch, highres_fix, highres_1st=False):
|
||||
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
|
||||
batch_size = len(batch)
|
||||
|
||||
# highres_fixの処理
|
||||
if highres_fix and not highres_1st:
|
||||
# 1st stageのバッチを作成して呼び出す
|
||||
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
||||
print("process 1st stage1")
|
||||
batch_1st = []
|
||||
for params1, (width, height, steps, scale, negative_scale, strength) in batch:
|
||||
for base, ext in batch:
|
||||
width_1st = int(width * args.highres_fix_scale + .5)
|
||||
height_1st = int(height * args.highres_fix_scale + .5)
|
||||
width_1st = width_1st - width_1st % 32
|
||||
height_1st = height_1st - height_1st % 32
|
||||
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
|
||||
|
||||
bd_1st = BatchData(base, BatchDataExt(width_1st, height_1st, args.highres_fix_steps,
|
||||
ext.scale, ext.negative_scale, ext.strength, ext.network_muls))
|
||||
batch_1st.append(bd_1st)
|
||||
images_1st = process_batch(batch_1st, True, True)
|
||||
|
||||
# 2nd stageのバッチを作成して以下処理する
|
||||
print("process 2nd stage1")
|
||||
batch_2nd = []
|
||||
for i, (b1, image) in enumerate(zip(batch, images_1st)):
|
||||
image = image.resize((width, height), resample=PIL.Image.LANCZOS)
|
||||
(step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
|
||||
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
|
||||
for i, (bd, image) in enumerate(zip(batch, images_1st)):
|
||||
image = image.resize((width, height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
||||
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:8]), bd.ext)
|
||||
batch_2nd.append(bd_2nd)
|
||||
batch = batch_2nd
|
||||
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
|
||||
height, steps, scale, negative_scale, strength) = batch[0]
|
||||
# このバッチの情報を取り出す
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image), \
|
||||
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
|
||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||
|
||||
prompts = []
|
||||
@@ -2321,6 +2354,10 @@ def main(args):
|
||||
guide_images = guide_images[0]
|
||||
|
||||
# generate
|
||||
if networks:
|
||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||
n.set_multiplier(m)
|
||||
|
||||
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
||||
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
||||
if highres_1st and not args.highres_fix_save_1st:
|
||||
@@ -2398,6 +2435,7 @@ def main(args):
|
||||
strength = 0.8 if args.strength is None else args.strength
|
||||
negative_prompt = ""
|
||||
clip_prompt = None
|
||||
network_muls = None
|
||||
|
||||
prompt_args = prompt.strip().split(' --')
|
||||
prompt = prompt_args[0]
|
||||
@@ -2461,6 +2499,15 @@ def main(args):
|
||||
clip_prompt = m.group(1)
|
||||
print(f"clip prompt: {clip_prompt}")
|
||||
continue
|
||||
|
||||
m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
|
||||
if m: # network multiplies
|
||||
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||
while len(network_muls) < len(networks):
|
||||
network_muls.append(network_muls[-1])
|
||||
print(f"network mul: {network_muls}")
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
@@ -2506,9 +2553,8 @@ def main(args):
|
||||
print("Use previous image as guide image.")
|
||||
guide_image = prev_image
|
||||
|
||||
# TODO named tupleか何かにする
|
||||
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||
(width, height, steps, scale, negative_scale, strength))
|
||||
b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||
BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
|
||||
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
||||
process_batch(batch_data, highres_fix)
|
||||
batch_data.clear()
|
||||
@@ -2578,12 +2624,15 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--opt_channels_last", action='store_true',
|
||||
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
||||
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
||||
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
||||
help='additiona network module to use / 追加ネットワークを使う時そのモジュール名')
|
||||
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
||||
help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||
help='additiona network weights to load / 追加ネットワークの重み')
|
||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*',
|
||||
help='additiona network multiplier / 追加ネットワークの効果の倍率')
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||
parser.add_argument("--network_show_meta", action='store_true',
|
||||
help='show metadata of network model / ネットワークモデルのメタデータを表示する')
|
||||
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
||||
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||
|
||||
Reference in New Issue
Block a user