mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge branch 'main' of https://github.com/kohya-ss/sd-scripts
This commit is contained in:
@@ -1,38 +1,3 @@
|
||||
# txt2img with Diffusers: supports SD checkpoints, EulerScheduler, clip-skip, 225 tokens, Hypernetwork etc...
|
||||
|
||||
# v2: CLIP guided Stable Diffusion, Image guided Stable Diffusion, highres. fix
|
||||
# v3: Add dpmsolver/dpmsolver++, add VAE loading, add upscale, add 'bf16', fix the issue network_mul is not working
|
||||
# v4: SD2.0 support (new U-Net/text encoder/tokenizer), simplify by DiffUsers 0.9.0, no_preview in interactive mode
|
||||
# v5: fix clip_sample=True for scheduler, add VGG guidance
|
||||
# v6: refactor to use model util, load VAE without vae folder, support safe tensors
|
||||
# v7: add use_original_file_name and iter_same_seed option, change vgg16 guide input image size,
|
||||
# Diffusers 0.10.0 (support new schedulers (dpm_2, dpm_2_a, heun, dpmsingle), supports all scheduler in v-prediction)
|
||||
# v8: accept wildcard for ckpt name (when only one file is matched), fix a bug app crushes because PIL image doesn't have filename attr sometimes,
|
||||
# v9: sort file names, fix an issue in img2img when prompt from metadata with images_per_prompt>1
|
||||
# v10: fix app crashes when different image size in prompts
|
||||
|
||||
# Copyright 2022 kohya_ss @kohya_ss
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# license of included scripts:
|
||||
|
||||
# FlashAttention: based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
|
||||
# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
|
||||
|
||||
# Diffusers (model conversion, CLIP guided stable diffusion, schedulers etc.):
|
||||
# ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE
|
||||
|
||||
"""
|
||||
VGG(
|
||||
(features): Sequential(
|
||||
@@ -517,7 +482,7 @@ class PipelineLike():
|
||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||
|
||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
@@ -1982,26 +1947,42 @@ def main(args):
|
||||
vgg16_model.to(dtype).to(device)
|
||||
|
||||
# networkを組み込む
|
||||
if args.network_module is not None:
|
||||
# assert not args.diffusers_xformers, "cannot use network with diffusers_xformers / diffusers_xformers指定時はnetworkは利用できません"
|
||||
if args.network_module:
|
||||
networks = []
|
||||
for i, network_module in enumerate(args.network_module):
|
||||
print("import network module:", network_module)
|
||||
imported_module = importlib.import_module(network_module)
|
||||
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
|
||||
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
|
||||
|
||||
network = network_module.create_network(args.network_mul, args.network_dim, vae,text_encoder, unet) # , **net_kwargs)
|
||||
if network is None:
|
||||
return
|
||||
net_kwargs = {}
|
||||
if args.network_args and i < len(args.network_args):
|
||||
network_args = args.network_args[i]
|
||||
# TODO escape special chars
|
||||
network_args = network_args.split(";")
|
||||
for net_arg in network_args:
|
||||
key, value = net_arg.split("=")
|
||||
net_kwargs[key] = value
|
||||
|
||||
print("load network weights from:", args.network_weights)
|
||||
network.load_weights(args.network_weights)
|
||||
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
|
||||
if network is None:
|
||||
return
|
||||
|
||||
network.apply_to(text_encoder, unet)
|
||||
if args.network_weights and i < len(args.network_weights):
|
||||
network_weight = args.network_weights[i]
|
||||
print("load network weights from:", network_weight)
|
||||
network.load_weights(network_weight)
|
||||
|
||||
if args.opt_channels_last:
|
||||
network.to(memory_format=torch.channels_last)
|
||||
network.to(dtype).to(device)
|
||||
network.apply_to(text_encoder, unet)
|
||||
|
||||
if args.opt_channels_last:
|
||||
network.to(memory_format=torch.channels_last)
|
||||
network.to(dtype).to(device)
|
||||
|
||||
networks.append(network)
|
||||
else:
|
||||
network = None
|
||||
networks = []
|
||||
|
||||
if args.opt_channels_last:
|
||||
print(f"set optimizing: channels last")
|
||||
@@ -2010,8 +1991,9 @@ def main(args):
|
||||
unet.to(memory_format=torch.channels_last)
|
||||
if clip_model is not None:
|
||||
clip_model.to(memory_format=torch.channels_last)
|
||||
if network is not None:
|
||||
network.to(memory_format=torch.channels_last)
|
||||
if networks:
|
||||
for network in networks:
|
||||
network.to(memory_format=torch.channels_last)
|
||||
if vgg16_model is not None:
|
||||
vgg16_model.to(memory_format=torch.channels_last)
|
||||
|
||||
@@ -2053,7 +2035,7 @@ def main(args):
|
||||
print(f"convert image to RGB from {image.mode}: {p}")
|
||||
image = image.convert("RGB")
|
||||
images.append(image)
|
||||
|
||||
|
||||
return images
|
||||
|
||||
def resize_images(imgs, size):
|
||||
@@ -2481,19 +2463,24 @@ if __name__ == '__main__':
|
||||
# help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
|
||||
parser.add_argument("--seed", type=int, default=None,
|
||||
help="seed, or seed of seeds in multiple generation / 1枚生成時のseed、または複数枚生成時の乱数seedを決めるためのseed")
|
||||
parser.add_argument("--iter_same_seed", action='store_true', help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)')
|
||||
parser.add_argument("--iter_same_seed", action='store_true',
|
||||
help='use same seed for all prompts in iteration if no seed specified / 乱数seedの指定がないとき繰り返し内はすべて同じseedを使う(プロンプト間の差異の比較用)')
|
||||
parser.add_argument("--fp16", action='store_true', help='use fp16 / fp16を指定し省メモリ化する')
|
||||
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
|
||||
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
|
||||
parser.add_argument("--diffusers_xformers", action='store_true',
|
||||
help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
||||
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
|
||||
parser.add_argument("--opt_channels_last", action='store_true',
|
||||
help='set channels last option to model / モデルにchannels lastを指定し最適化する')
|
||||
parser.add_argument("--network_module", type=str, default=None, help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
||||
parser.add_argument("--network_weights", type=str, default=None, help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||
parser.add_argument("--network_mul", type=float, default=1.0, help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||
parser.add_argument("--network_dim", type=int, default=None,
|
||||
help='set channels last option to model / モデルにchannles lastを指定し最適化する')
|
||||
parser.add_argument("--network_module", type=str, default=None, nargs='*',
|
||||
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
|
||||
parser.add_argument("--network_weights", type=str, default=None, nargs='*',
|
||||
help='Hypernetwork weights to load / Hypernetworkの重み')
|
||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
|
||||
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
|
||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||
|
||||
158
networks/extract_lora_from_models.py
Normal file
158
networks/extract_lora_from_models.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# extract approximating LoRA by svd from two SD models
|
||||
# The code is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo!
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
MIN_DIFF = 1e-6
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def svd(args):
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
|
||||
print(f"loading SD model : {args.model_org}")
|
||||
text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_org)
|
||||
print(f"loading SD model : {args.model_tuned}")
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
||||
|
||||
# create LoRA network to extract weights
|
||||
lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o)
|
||||
lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(
|
||||
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||
|
||||
# get diffs
|
||||
diffs = {}
|
||||
text_encoder_different = False
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.text_encoder_loras, lora_network_t.text_encoder_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
|
||||
# Text Encoder might be same
|
||||
if torch.max(torch.abs(diff)) > MIN_DIFF:
|
||||
text_encoder_different = True
|
||||
|
||||
diff = diff.float()
|
||||
diffs[lora_name] = diff
|
||||
|
||||
if not text_encoder_different:
|
||||
print("Text encoder is same. Extract U-Net only.")
|
||||
lora_network_o.text_encoder_loras = []
|
||||
diffs = {}
|
||||
|
||||
for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
|
||||
lora_name = lora_o.lora_name
|
||||
module_o = lora_o.org_module
|
||||
module_t = lora_t.org_module
|
||||
diff = module_t.weight - module_o.weight
|
||||
diff = diff.float()
|
||||
|
||||
if args.device:
|
||||
diff = diff.to(args.device)
|
||||
|
||||
diffs[lora_name] = diff
|
||||
|
||||
# make LoRA with svd
|
||||
print("calculating by svd")
|
||||
rank = args.dim
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
conv2d = (len(mat.size()) == 4)
|
||||
if conv2d:
|
||||
mat = mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
U = U[:, :rank]
|
||||
S = S[:rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
Vh = Vh[:rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
# make state dict for LoRA
|
||||
lora_network_o.apply_to(text_encoder_o, unet_o, text_encoder_different, True) # to make state dict
|
||||
lora_sd = lora_network_o.state_dict()
|
||||
print(f"LoRA has {len(lora_sd)} weights.")
|
||||
|
||||
for key in list(lora_sd.keys()):
|
||||
lora_name = key.split('.')[0]
|
||||
i = 0 if "lora_up" in key else 1
|
||||
|
||||
weights = lora_weights[lora_name][i]
|
||||
# print(key, i, weights.size(), lora_sd[key].size())
|
||||
if len(lora_sd[key].size()) == 4:
|
||||
weights = weights.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
assert weights.size() == lora_sd[key].size()
|
||||
lora_sd[key] = weights
|
||||
|
||||
# load state dict to LoRA and save it
|
||||
info = lora_network_o.load_state_dict(lora_sd)
|
||||
print(f"Loading extracted LoRA weights: {info}")
|
||||
|
||||
dir_name = os.path.dirname(args.save_to)
|
||||
if dir_name and not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
lora_network_o.save_weights(args.save_to, save_dtype)
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はfloat")
|
||||
parser.add_argument("--model_org", type=str, default=None,
|
||||
help="Stable Diffusion original model: ckpt or safetensors file / 元モデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--model_tuned", type=str, default=None,
|
||||
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル(生成されるLoRAは元→派生の差分になります)、ckptまたはsafetensors")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数(デフォルト4)")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う")
|
||||
|
||||
args = parser.parse_args()
|
||||
svd(args)
|
||||
@@ -10,9 +10,7 @@
|
||||
|
||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||
|
||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルに、このリポジトリ内のスクリプトであらかじめマージしておく必要があります。マージ後のモデルファイルはLoRAの学習結果が反映されたものになります。
|
||||
|
||||
なお当リポジトリ内の画像生成スクリプトで生成する場合はマージ不要です。
|
||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extention](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||
|
||||
## 学習方法
|
||||
|
||||
@@ -24,7 +22,7 @@ DreamBoothの手法(identifier(sksなど)とclass、オプションで正
|
||||
|
||||
### DreamBoothの手法を用いる場合
|
||||
|
||||
note.com [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nba4eceaa4594) を参照してデータを用意してください。
|
||||
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
|
||||
|
||||
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。
|
||||
|
||||
@@ -110,7 +108,7 @@ python networks\merge_lora.py --sd_model ..\model\model.ckpt
|
||||
|
||||
### 複数のLoRAのモデルをマージする
|
||||
|
||||
結局のところSDモデルにマージしないと推論できないのであまり使い道はないかもしれません。ただ、複数のLoRAモデルをひとつずつSDモデルにマージしていく場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
|
||||
複数のLoRAモデルをひとつずつSDモデルに適用する場合と、複数のLoRAモデルをマージしてからSDモデルにマージする場合とは、計算順序の関連で微妙に異なる結果になります。
|
||||
|
||||
たとえば以下のようなコマンドラインになります。
|
||||
|
||||
@@ -144,6 +142,40 @@ gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim
|
||||
|
||||
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。
|
||||
|
||||
## 二つのモデルの差分からLoRAモデルを作成する
|
||||
|
||||
[こちらのディスカッション](https://github.com/cloneofsimo/lora/discussions/56)を参考に実装したものです。数式はそのまま使わせていただきました(よく理解していませんが近似には特異値分解を用いるようです)。
|
||||
|
||||
二つのモデル(たとえばfine tuningの元モデルとfine tuning後のモデル)の差分を、LoRAで近似します。
|
||||
|
||||
### スクリプトの実行方法
|
||||
|
||||
以下のように指定してください。
|
||||
```
|
||||
python networks\extract_lora_from_models.py --model_org base-model.ckpt
|
||||
--model_tuned fine-tuned-model.ckpt
|
||||
--save_to lora-weights.safetensors --dim 4
|
||||
```
|
||||
|
||||
--model_orgオプションに元のStable Diffusionモデルを指定します。作成したLoRAモデルを適用する場合は、このモデルを指定して適用することになります。.ckptまたは.safetensorsが指定できます。
|
||||
|
||||
--model_tunedオプションに差分を抽出する対象のStable Diffusionモデルを指定します。たとえばfine tuningやDreamBooth後のモデルを指定します。.ckptまたは.safetensorsが指定できます。
|
||||
|
||||
--save_toにLoRAモデルの保存先を指定します。--dimにLoRAの次元数を指定します。
|
||||
|
||||
生成されたLoRAモデルは、学習したLoRAモデルと同様に使用できます。
|
||||
|
||||
Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRAとなります。
|
||||
|
||||
### その他のオプション
|
||||
|
||||
- --v2
|
||||
- v2.xのStable Diffusionモデルを使う場合に指定してください。
|
||||
- --device
|
||||
- ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。
|
||||
- --save_precision
|
||||
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
|
||||
|
||||
## 追加情報
|
||||
|
||||
### cloneofsimo氏のリポジトリとの違い
|
||||
|
||||
Reference in New Issue
Block a user