Merge pull request #174 from kohya-ss/dev

Dev
This commit is contained in:
Kohya S
2023-02-10 23:11:12 +09:00
committed by GitHub
6 changed files with 537 additions and 120 deletions

View File

@@ -8,7 +8,18 @@ Note: The LoRA models for SD 2.x is not supported too in Web UI.
- 10 Feb. 2023, 2023/2/10: - 10 Feb. 2023, 2023/2/10:
- Updated ``requirements.txt`` to prevent upgrading with pip taking a long time or failure to upgrade. - Updated ``requirements.txt`` to prevent upgrading with pip taking a long time or failure to upgrade.
- ``resize_lora.py`` keeps the metadata of the model. ``dimension is resized from ...`` is added to the top of ``ss_training_comment``.
- ``merge_lora.py`` supports models with different ``alpha``s. If there is a problem, old version is ``merge_lora_old.py``.
- ``svd_merge_lora.py`` is added. This script merges LoRA models with any rank (dim) and alpha, and approximate a new LoRA with svd for a specified rank (dim).
- Note: merging scripts erase the metadata currently.
- ``resize_images_to_resolution.py`` supports multibyte characters in filenames.
- pipでの更新が長時間掛かったり、更新に失敗したりするのを防ぐため、``requirements.txt``を更新しました。 - pipでの更新が長時間掛かったり、更新に失敗したりするのを防ぐため、``requirements.txt``を更新しました。
- ``resize_lora.py``がメタデータを保持するようになりました。 ``dimension is resized from ...`` という文字列が ``ss_training_comment`` の先頭に追加されます。
- ``merge_lora.py``がalphaが異なるモデルをサポートしました。 何か問題がありましたら旧バージョン ``merge_lora_old.py`` をお使いください。
- ``svd_merge_lora.py`` を追加しました。 複数の任意のdim (rank)、alphaのLoRAモデルをマージし、svdで任意dim(rank)のLoRAで近似します。
- 注:マージ系のスクリプトは現時点ではメタデータを消去しますのでご注意ください。
- ``resize_images_to_resolution.py``が日本語ファイル名をサポートしました。
- 9 Feb. 2023, 2023/2/9: - 9 Feb. 2023, 2023/2/9:
- Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource! - Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
- ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout). - ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).

View File

@@ -1,5 +1,5 @@
import math
import argparse import argparse
import os import os
import torch import torch
@@ -85,43 +85,76 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
weight = weight + ratio * (up_weight @ down_weight) * scale weight = weight + ratio * (up_weight @ down_weight) * scale
else: else:
# conv2d # conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight) module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype): def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {} base_alphas = {} # alpha for merged model
base_dims = {}
alpha = None merged_sd = {}
dim = None
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype) lora_sd = load_state_dict(model, merge_dtype)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if 'alpha' in key:
lora_module_name = key[:key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[:key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...") print(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if 'alpha' in key:
if key in merged_sd: continue
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else: lora_module_name = key[:key.rfind(".lora_")]
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key] base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else: else:
if key in merged_sd: merged_sd[key] = lora_sd[key] * scale
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" # set alpha to sd
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio for lora_module_name, alpha in base_alphas.items():
else: key = lora_module_name + ".alpha"
if "lora_down" in key: merged_sd[key] = torch.tensor(alpha)
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}") print("merged model")
if alpha is None: print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
alpha = dim
return merged_sd, dim, alpha return merged_sd
def merge(args): def merge(args):
@@ -152,7 +185,7 @@ def merge(args):
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae) args.sd_model, 0, 0, save_dtype, vae)
else: else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}") print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)

179
networks/merge_lora_old.py Normal file
View File

@@ -0,0 +1,179 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
text_encoder.to(merge_dtype)
unet.to(merge_dtype)
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder, unet]):
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
# conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {}
alpha = None
dim = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
else:
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
return merged_sd, dim, alpha
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
args = parser.parse_args()
merge(args)

View File

@@ -5,148 +5,169 @@
import argparse import argparse
import os import os
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm from tqdm import tqdm
from library import train_util, model_util
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors': if model_util.is_safetensors(file_name):
sd = load_file(file_name) sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else: else:
sd = torch.load(file_name, map_location='cpu') sd = torch.load(file_name, map_location='cpu')
metadata = None
for key in list(sd.keys()): for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor: if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype) sd[key] = sd[key].to(dtype)
return sd
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None: if dtype is not None:
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor: if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype) state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors': if model_util.is_safetensors(file_name):
save_file(model, file_name) save_file(model, file_name, metadata)
else: else:
torch.save(model, file_name) torch.save(model, file_name)
def resize_lora_model(model, new_rank, merge_dtype, save_dtype): def resize_lora_model(lora_sd, new_rank, save_dtype, device):
print("Loading Model...") network_alpha = None
lora_sd = load_state_dict(model, merge_dtype) network_dim = None
network_alpha = None CLAMP_QUANTILE = 0.99
network_dim = None
CLAMP_QUANTILE = 0.99 # Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
# Extract loaded lora dim and alpha scale = network_alpha/network_dim
for key, value in lora_sd.items(): new_alpha = float(scale*new_rank) # calculate new alpha from scale
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
new_alpha = float(scale*new_rank) # calculate new alpha from scale
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}") lora_down_weight = None
lora_up_weight = None
lora_down_weight = None o_lora_sd = lora_sd.copy()
lora_up_weight = None block_down_name = None
block_up_name = None
o_lora_sd = lora_sd.copy() print("resizing lora...")
block_down_name = None with torch.no_grad():
block_up_name = None for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
print("resizing lora...") weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) if (block_down_name == block_up_name) and weights_loaded:
if (block_down_name == block_up_name) and weights_loaded: conv2d = (len(lora_down_weight.size()) == 4)
conv2d = (len(lora_down_weight.size()) == 4) if conv2d:
lora_down_weight = lora_down_weight.squeeze()
if conv2d: lora_up_weight = lora_up_weight.squeeze()
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
if args.device: if device:
org_device = lora_up_weight.device org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device) lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device) lora_down_weight = lora_down_weight.to(args.device)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
U, S, Vh = torch.linalg.svd(full_weight_matrix) U, S, Vh = torch.linalg.svd(full_weight_matrix)
U = U[:, :new_rank] U = U[:, :new_rank]
S = S[:new_rank] S = S[:new_rank]
U = U @ torch.diag(S) U = U @ torch.diag(S)
Vh = Vh[:new_rank, :] Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()]) dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE) hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val low_val = -hi_val
U = U.clamp(low_val, hi_val) U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val) Vh = Vh.clamp(low_val, hi_val)
if conv2d:
U = U.unsqueeze(2).unsqueeze(3)
Vh = Vh.unsqueeze(2).unsqueeze(3)
if args.device:
U = U.to(org_device)
Vh = Vh.to(org_device)
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous() if conv2d:
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() U = U.unsqueeze(2).unsqueeze(3)
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype) Vh = Vh.unsqueeze(2).unsqueeze(3)
block_down_name = None if args.device:
block_up_name = None U = U.to(org_device)
lora_down_weight = None Vh = Vh.to(org_device)
lora_up_weight = None
weights_loaded = False o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
print("resizing complete")
return o_lora_sd, network_dim, new_alpha
print("resizing complete")
return o_lora_sd
def resize(args): def resize(args):
def str_to_dtype(p): def str_to_dtype(p):
if p == 'float': if p == 'float':
return torch.float return torch.float
if p == 'fp16': if p == 'fp16':
return torch.float16 return torch.float16
if p == 'bf16': if p == 'bf16':
return torch.bfloat16 return torch.bfloat16
return None return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision) save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None: if save_dtype is None:
save_dtype = merge_dtype save_dtype = merge_dtype
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype) print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print(f"saving model to: {args.save_to}") print("resizing rank...")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
# update metadata
if metadata is None:
metadata = {}
comment = metadata.get("ss_training_comment", "")
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
if __name__ == '__main__': if __name__ == '__main__':

164
networks/svd_merge_lora.py Normal file
View File

@@ -0,0 +1,164 @@
import math
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# merge
print(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if 'lora_down' not in key:
continue
lora_module_name = key[:key.rfind(".lora_down")]
down_weight = lora_sd[key]
network_dim = down_weight.size()[0]
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
in_dim = down_weight.size()[1]
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
# merge to weight
if device:
up_weight = up_weight.to(device)
down_weight = down_weight.to(device)
# W <- W + U * D
scale = (alpha / network_dim)
if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
merged_sd[lora_module_name] = weight
# extract from merged weights
print("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
conv2d = (len(mat.size()) == 4)
if conv2d:
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
up_weight = U
down_weight = Vh
if conv2d:
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
return merged_lora_sd
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args()
merge(args)

View File

@@ -4,6 +4,8 @@ import cv2
import argparse import argparse
import shutil import shutil
import math import math
from PIL import Image
import numpy as np
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False): def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
@@ -35,7 +37,11 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
continue continue
# Load image # Load image
img = cv2.imread(os.path.join(src_img_folder, filename)) # img = cv2.imread(os.path.join(src_img_folder, filename))
image = Image.open(os.path.join(src_img_folder, filename))
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
base, _ = os.path.splitext(filename) base, _ = os.path.splitext(filename)
for max_resolution in max_resolutions: for max_resolution in max_resolutions:
@@ -72,7 +78,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg') new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
# Save resized image in dst_img_folder # Save resized image in dst_img_folder
cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100]) # cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
image = Image.fromarray(img)
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved" proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}") print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")