mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support network mul from prompt
This commit is contained in:
@@ -47,7 +47,7 @@ VGG(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import List, Optional, Union
|
from typing import Any, List, NamedTuple, Optional, Tuple, Union, Callable
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
@@ -60,7 +60,6 @@ import math
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
from typing import Any, Callable, List, Optional, Union
|
|
||||||
|
|
||||||
import diffusers
|
import diffusers
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -1817,6 +1816,34 @@ def preprocess_mask(mask):
|
|||||||
# return text_encoder
|
# return text_encoder
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDataBase(NamedTuple):
|
||||||
|
# バッチ分割が必要ないデータ
|
||||||
|
step: int
|
||||||
|
prompt: str
|
||||||
|
negative_prompt: str
|
||||||
|
seed: int
|
||||||
|
init_image: Any
|
||||||
|
mask_image: Any
|
||||||
|
clip_prompt: str
|
||||||
|
guide_image: Any
|
||||||
|
|
||||||
|
|
||||||
|
class BatchDataExt(NamedTuple):
|
||||||
|
# バッチ分割が必要なデータ
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
steps: int
|
||||||
|
scale: float
|
||||||
|
negative_scale: float
|
||||||
|
strength: float
|
||||||
|
network_muls: Tuple[float]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchData(NamedTuple):
|
||||||
|
base: BatchDataBase
|
||||||
|
ext: BatchDataExt
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
@@ -1995,11 +2022,13 @@ def main(args):
|
|||||||
# networkを組み込む
|
# networkを組み込む
|
||||||
if args.network_module:
|
if args.network_module:
|
||||||
networks = []
|
networks = []
|
||||||
|
network_default_muls = []
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
print("import network module:", network_module)
|
print("import network module:", network_module)
|
||||||
imported_module = importlib.import_module(network_module)
|
imported_module = importlib.import_module(network_module)
|
||||||
|
|
||||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||||
|
network_default_muls.append(network_mul)
|
||||||
|
|
||||||
net_kwargs = {}
|
net_kwargs = {}
|
||||||
if args.network_args and i < len(args.network_args):
|
if args.network_args and i < len(args.network_args):
|
||||||
@@ -2014,7 +2043,7 @@ def main(args):
|
|||||||
network_weight = args.network_weights[i]
|
network_weight = args.network_weights[i]
|
||||||
print("load network weights from:", network_weight)
|
print("load network weights from:", network_weight)
|
||||||
|
|
||||||
if model_util.is_safetensors(network_weight):
|
if model_util.is_safetensors(network_weight) and args.network_show_meta:
|
||||||
from safetensors.torch import safe_open
|
from safetensors.torch import safe_open
|
||||||
with safe_open(network_weight, framework="pt") as f:
|
with safe_open(network_weight, framework="pt") as f:
|
||||||
metadata = f.metadata()
|
metadata = f.metadata()
|
||||||
@@ -2219,33 +2248,37 @@ def main(args):
|
|||||||
iter_seed = random.randint(0, 0x7fffffff)
|
iter_seed = random.randint(0, 0x7fffffff)
|
||||||
|
|
||||||
# バッチ処理の関数
|
# バッチ処理の関数
|
||||||
def process_batch(batch, highres_fix, highres_1st=False):
|
def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
|
||||||
batch_size = len(batch)
|
batch_size = len(batch)
|
||||||
|
|
||||||
# highres_fixの処理
|
# highres_fixの処理
|
||||||
if highres_fix and not highres_1st:
|
if highres_fix and not highres_1st:
|
||||||
# 1st stageのバッチを作成して呼び出す
|
# 1st stageのバッチを作成して呼び出す:サイズを小さくして呼び出す
|
||||||
print("process 1st stage1")
|
print("process 1st stage1")
|
||||||
batch_1st = []
|
batch_1st = []
|
||||||
for params1, (width, height, steps, scale, negative_scale, strength) in batch:
|
for base, ext in batch:
|
||||||
width_1st = int(width * args.highres_fix_scale + .5)
|
width_1st = int(width * args.highres_fix_scale + .5)
|
||||||
height_1st = int(height * args.highres_fix_scale + .5)
|
height_1st = int(height * args.highres_fix_scale + .5)
|
||||||
width_1st = width_1st - width_1st % 32
|
width_1st = width_1st - width_1st % 32
|
||||||
height_1st = height_1st - height_1st % 32
|
height_1st = height_1st - height_1st % 32
|
||||||
batch_1st.append((params1, (width_1st, height_1st, args.highres_fix_steps, scale, negative_scale, strength)))
|
|
||||||
|
bd_1st = BatchData(base, BatchDataExt(width_1st, height_1st, args.highres_fix_steps,
|
||||||
|
ext.scale, ext.negative_scale, ext.strength, ext.network_muls))
|
||||||
|
batch_1st.append(bd_1st)
|
||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
|
|
||||||
# 2nd stageのバッチを作成して以下処理する
|
# 2nd stageのバッチを作成して以下処理する
|
||||||
print("process 2nd stage1")
|
print("process 2nd stage1")
|
||||||
batch_2nd = []
|
batch_2nd = []
|
||||||
for i, (b1, image) in enumerate(zip(batch, images_1st)):
|
for i, (bd, image) in enumerate(zip(batch, images_1st)):
|
||||||
image = image.resize((width, height), resample=PIL.Image.LANCZOS)
|
image = image.resize((width, height), resample=PIL.Image.LANCZOS) # img2imgとして設定
|
||||||
(step, prompt, negative_prompt, seed, _, _, clip_prompt, guide_image), params2 = b1
|
bd_2nd = BatchData(BatchDataBase(*bd.base[0:3], bd.base.seed+1, image, None, *bd.base[6:8]), bd.ext)
|
||||||
batch_2nd.append(((step, prompt, negative_prompt, seed+1, image, None, clip_prompt, guide_image), params2))
|
batch_2nd.append(bd_2nd)
|
||||||
batch = batch_2nd
|
batch = batch_2nd
|
||||||
|
|
||||||
(step_first, _, _, _, init_image, mask_image, _, guide_image), (width,
|
# このバッチの情報を取り出す
|
||||||
height, steps, scale, negative_scale, strength) = batch[0]
|
(step_first, _, _, _, init_image, mask_image, _, guide_image), \
|
||||||
|
(width, height, steps, scale, negative_scale, strength, network_muls) = batch[0]
|
||||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||||
|
|
||||||
prompts = []
|
prompts = []
|
||||||
@@ -2321,6 +2354,10 @@ def main(args):
|
|||||||
guide_images = guide_images[0]
|
guide_images = guide_images[0]
|
||||||
|
|
||||||
# generate
|
# generate
|
||||||
|
if networks:
|
||||||
|
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||||
|
n.set_multiplier(m)
|
||||||
|
|
||||||
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
images = pipe(prompts, negative_prompts, init_images, mask_images, height, width, steps, scale, negative_scale, strength, latents=start_code,
|
||||||
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
output_type='pil', max_embeddings_multiples=max_embeddings_multiples, img2img_noise=i2i_noises, clip_prompts=clip_prompts, clip_guide_images=guide_images)[0]
|
||||||
if highres_1st and not args.highres_fix_save_1st:
|
if highres_1st and not args.highres_fix_save_1st:
|
||||||
@@ -2398,6 +2435,7 @@ def main(args):
|
|||||||
strength = 0.8 if args.strength is None else args.strength
|
strength = 0.8 if args.strength is None else args.strength
|
||||||
negative_prompt = ""
|
negative_prompt = ""
|
||||||
clip_prompt = None
|
clip_prompt = None
|
||||||
|
network_muls = None
|
||||||
|
|
||||||
prompt_args = prompt.strip().split(' --')
|
prompt_args = prompt.strip().split(' --')
|
||||||
prompt = prompt_args[0]
|
prompt = prompt_args[0]
|
||||||
@@ -2461,6 +2499,15 @@ def main(args):
|
|||||||
clip_prompt = m.group(1)
|
clip_prompt = m.group(1)
|
||||||
print(f"clip prompt: {clip_prompt}")
|
print(f"clip prompt: {clip_prompt}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
m = re.match(r'am ([\d\.\-,]+)', parg, re.IGNORECASE)
|
||||||
|
if m: # network multiplies
|
||||||
|
network_muls = [float(v) for v in m.group(1).split(",")]
|
||||||
|
while len(network_muls) < len(networks):
|
||||||
|
network_muls.append(network_muls[-1])
|
||||||
|
print(f"network mul: {network_muls}")
|
||||||
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||||
print(ex)
|
print(ex)
|
||||||
@@ -2506,9 +2553,8 @@ def main(args):
|
|||||||
print("Use previous image as guide image.")
|
print("Use previous image as guide image.")
|
||||||
guide_image = prev_image
|
guide_image = prev_image
|
||||||
|
|
||||||
# TODO named tupleか何かにする
|
b1 = BatchData(BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||||
b1 = ((global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
BatchDataExt(width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None))
|
||||||
(width, height, steps, scale, negative_scale, strength))
|
|
||||||
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
if len(batch_data) > 0 and batch_data[-1][1] != b1[1]: # バッチ分割必要?
|
||||||
process_batch(batch_data, highres_fix)
|
process_batch(batch_data, highres_fix)
|
||||||
batch_data.clear()
|
batch_data.clear()
|
||||||
@@ -2578,12 +2624,15 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--opt_channels_last", action='store_true',
|
parser.add_argument("--opt_channels_last", action='store_true',
|
||||||
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
||||||
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
||||||
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
help='additiona network module to use / 追加ネットワークを使う時そのモジュール名')
|
||||||
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
||||||
help='Hypernetwork weights to load / Hypernetworkの重み')
|
help='additiona network weights to load / 追加ネットワークの重み')
|
||||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
parser.add_argument("--network_mul", type=float, default=None, nargs='*',
|
||||||
|
help='additiona network multiplier / 追加ネットワークの効果の倍率')
|
||||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||||
|
parser.add_argument("--network_show_meta", action='store_true',
|
||||||
|
help='show metadata of network model / ネットワークモデルのメタデータを表示する')
|
||||||
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
||||||
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
||||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||||
|
|||||||
@@ -126,6 +126,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||||
names.add(lora.lora_name)
|
names.add(lora.lora_name)
|
||||||
|
|
||||||
|
def set_multiplier(self, multiplier):
|
||||||
|
self.multiplier = multiplier
|
||||||
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
lora.multiplier = self.multiplier
|
||||||
|
|
||||||
def load_weights(self, file):
|
def load_weights(self, file):
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
from safetensors.torch import load_file, safe_open
|
from safetensors.torch import load_file, safe_open
|
||||||
|
|||||||
Reference in New Issue
Block a user