mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add DyLoRA (experimental)
This commit is contained in:
448
networks/dylora.py
Normal file
448
networks/dylora.py
Normal file
@@ -0,0 +1,448 @@
|
||||
# some codes are copied from:
|
||||
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
|
||||
|
||||
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
|
||||
# Changes made to the original code:
|
||||
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
|
||||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import List, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DyLoRAModule(torch.nn.Module):
|
||||
"""
|
||||
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||
"""
|
||||
|
||||
# NOTE: support dropout in future
|
||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
self.lora_dim = lora_dim
|
||||
self.unit = unit
|
||||
assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
|
||||
self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)
|
||||
|
||||
if self.is_conv2d and self.is_conv2d_3x3:
|
||||
kernel_size = org_module.kernel_size
|
||||
self.stride = org_module.stride
|
||||
self.padding = org_module.padding
|
||||
self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size)))
|
||||
self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1)))
|
||||
else:
|
||||
self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim)))
|
||||
self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim)))
|
||||
|
||||
# same as microsoft's
|
||||
torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
torch.nn.init.zeros_(self.lora_B)
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
result = self.org_forward(x)
|
||||
|
||||
# specify the dynamic rank
|
||||
trainable_rank = random.randint(0, self.lora_dim - 1)
|
||||
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
|
||||
|
||||
# 一部のパラメータを固定して、残りのパラメータを学習する
|
||||
|
||||
# make lora_A
|
||||
if trainable_rank > 0:
|
||||
lora_A_nt1 = [self.lora_A[:trainable_rank].detach()]
|
||||
else:
|
||||
lora_A_nt1 = []
|
||||
|
||||
lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit]
|
||||
|
||||
if trainable_rank < self.lora_dim - self.unit:
|
||||
lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()]
|
||||
else:
|
||||
lora_A_nt2 = []
|
||||
|
||||
lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0)
|
||||
|
||||
# make lora_B
|
||||
if trainable_rank > 0:
|
||||
lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()]
|
||||
else:
|
||||
lora_B_nt1 = []
|
||||
|
||||
lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit]
|
||||
|
||||
if trainable_rank < self.lora_dim - self.unit:
|
||||
lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()]
|
||||
else:
|
||||
lora_B_nt2 = []
|
||||
|
||||
lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1)
|
||||
|
||||
# print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size())
|
||||
|
||||
# calculate with lora_A and lora_B
|
||||
if self.is_conv2d_3x3:
|
||||
ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
|
||||
ab = torch.nn.functional.conv2d(ab, lora_B)
|
||||
else:
|
||||
ab = x
|
||||
if self.is_conv2d:
|
||||
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)
|
||||
|
||||
ab = torch.nn.functional.linear(ab, lora_A)
|
||||
ab = torch.nn.functional.linear(ab, lora_B)
|
||||
|
||||
if self.is_conv2d:
|
||||
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])
|
||||
|
||||
# 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな)
|
||||
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
|
||||
|
||||
# NOTE weightに加算してからlinear/conv2dを呼んだほうが速いかも
|
||||
return result
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
# state dictを通常のLoRAと同じにする
|
||||
state_dict = super().state_dict(destination, prefix, keep_vars)
|
||||
|
||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_A")
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
|
||||
state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight
|
||||
|
||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_B")
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
|
||||
state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight
|
||||
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
# 通常のLoRAと同じstate dictを読み込めるようにする
|
||||
state_dict = state_dict.copy()
|
||||
|
||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight")
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
||||
state_dict[self.lora_name + ".lora_A"] = lora_A_weight
|
||||
|
||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight")
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
||||
state_dict[self.lora_name + ".lora_B"] = lora_B_weight
|
||||
|
||||
super().load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
if network_dim is None:
|
||||
network_dim = 4 # default
|
||||
if network_alpha is None:
|
||||
network_alpha = 1.0
|
||||
|
||||
# extract dim/alpha for conv2d, and block dim
|
||||
conv_dim = kwargs.get("conv_dim", None)
|
||||
conv_alpha = kwargs.get("conv_alpha", None)
|
||||
unit = kwargs.get("unit", None)
|
||||
if conv_dim is not None:
|
||||
conv_dim = int(conv_dim)
|
||||
assert conv_dim == network_dim, "conv_dim must be same as network_dim"
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
if unit is not None:
|
||||
unit = int(unit)
|
||||
else:
|
||||
unit = 1
|
||||
|
||||
network = DyLoRANetwork(
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=multiplier,
|
||||
lora_dim=network_dim,
|
||||
alpha=network_alpha,
|
||||
apply_to_conv=conv_dim is not None,
|
||||
unit=unit,
|
||||
varbose=True,
|
||||
)
|
||||
return network
|
||||
|
||||
|
||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||
if weights_sd is None:
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file, safe_open
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
# get dim/alpha mapping
|
||||
modules_dim = {}
|
||||
modules_alpha = {}
|
||||
for key, value in weights_sd.items():
|
||||
if "." not in key:
|
||||
continue
|
||||
|
||||
lora_name = key.split(".")[0]
|
||||
if "alpha" in key:
|
||||
modules_alpha[lora_name] = value
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
if key not in modules_alpha:
|
||||
modules_alpha = modules_dim[key]
|
||||
|
||||
module_class = DyLoRAModule
|
||||
|
||||
network = DyLoRANetwork(
|
||||
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||
)
|
||||
return network, weights_sd
|
||||
|
||||
|
||||
class DyLoRANetwork(torch.nn.Module):
|
||||
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
|
||||
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
|
||||
LORA_PREFIX_UNET = "lora_unet"
|
||||
LORA_PREFIX_TEXT_ENCODER = "lora_te"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder,
|
||||
unet,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
apply_to_conv=False,
|
||||
modules_dim=None,
|
||||
modules_alpha=None,
|
||||
unit=1,
|
||||
module_class=DyLoRAModule,
|
||||
varbose=False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
self.alpha = alpha
|
||||
self.apply_to_conv = apply_to_conv
|
||||
|
||||
if modules_dim is not None:
|
||||
print(f"create LoRA network from weights")
|
||||
else:
|
||||
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
|
||||
if self.apply_to_conv:
|
||||
print(f"apply LoRA to Conv2d with kernel size (3,3).")
|
||||
|
||||
# create module instances
|
||||
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
|
||||
prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
loras = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
is_linear = child_module.__class__.__name__ == "Linear"
|
||||
is_conv2d = child_module.__class__.__name__ == "Conv2d"
|
||||
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||
|
||||
if is_linear or is_conv2d:
|
||||
lora_name = prefix + "." + name + "." + child_name
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
dim = None
|
||||
alpha = None
|
||||
if modules_dim is not None:
|
||||
if lora_name in modules_dim:
|
||||
dim = modules_dim[lora_name]
|
||||
alpha = modules_alpha[lora_name]
|
||||
else:
|
||||
if is_linear or is_conv2d_1x1 or apply_to_conv:
|
||||
dim = self.lora_dim
|
||||
alpha = self.alpha
|
||||
|
||||
if dim is None or dim == 0:
|
||||
continue
|
||||
|
||||
# dropout and fan_in_fan_out is default
|
||||
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
|
||||
loras.append(lora)
|
||||
return loras
|
||||
|
||||
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
|
||||
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
if modules_dim is not None or self.apply_to_conv:
|
||||
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras = create_modules(True, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
|
||||
def set_multiplier(self, multiplier):
|
||||
self.multiplier = multiplier
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.multiplier = self.multiplier
|
||||
|
||||
def load_weights(self, file):
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
weights_sd = load_file(file)
|
||||
else:
|
||||
weights_sd = torch.load(file, map_location="cpu")
|
||||
|
||||
info = self.load_state_dict(weights_sd, False)
|
||||
return info
|
||||
|
||||
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.apply_to()
|
||||
self.add_module(lora.lora_name, lora)
|
||||
|
||||
"""
|
||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||
apply_text_encoder = apply_unet = False
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
|
||||
apply_text_encoder = True
|
||||
elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
|
||||
apply_unet = True
|
||||
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
else:
|
||||
self.text_encoder_loras = []
|
||||
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
else:
|
||||
self.unet_loras = []
|
||||
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
sd_for_lora = {}
|
||||
for key in weights_sd.keys():
|
||||
if key.startswith(lora.lora_name):
|
||||
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
|
||||
lora.merge_to(sd_for_lora, dtype, device)
|
||||
|
||||
print(f"weights are merged")
|
||||
"""
|
||||
|
||||
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||
self.requires_grad_(True)
|
||||
all_params = []
|
||||
|
||||
def enumerate_params(loras):
|
||||
params = []
|
||||
for lora in loras:
|
||||
params.extend(lora.parameters())
|
||||
return params
|
||||
|
||||
if self.text_encoder_loras:
|
||||
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||
if text_encoder_lr is not None:
|
||||
param_data["lr"] = text_encoder_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
if self.unet_loras:
|
||||
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||
if unet_lr is not None:
|
||||
param_data["lr"] = unet_lr
|
||||
all_params.append(param_data)
|
||||
|
||||
return all_params
|
||||
|
||||
def enable_gradient_checkpointing(self):
|
||||
# not supported
|
||||
pass
|
||||
|
||||
def prepare_grad_etc(self, text_encoder, unet):
|
||||
self.requires_grad_(True)
|
||||
|
||||
def on_epoch_start(self, text_encoder, unet):
|
||||
self.train()
|
||||
|
||||
def get_trainable_params(self):
|
||||
return self.parameters()
|
||||
|
||||
def save_weights(self, file, dtype, metadata):
|
||||
if metadata is not None and len(metadata) == 0:
|
||||
metadata = None
|
||||
|
||||
state_dict = self.state_dict()
|
||||
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
v = state_dict[key]
|
||||
v = v.detach().clone().to("cpu").to(dtype)
|
||||
state_dict[key] = v
|
||||
|
||||
if os.path.splitext(file)[1] == ".safetensors":
|
||||
from safetensors.torch import save_file
|
||||
from library import train_util
|
||||
|
||||
# Precalculate model hashes to save time on indexing
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
save_file(state_dict, file, metadata)
|
||||
else:
|
||||
torch.save(state_dict, file)
|
||||
|
||||
# mask is a tensor with values from 0 to 1
|
||||
def set_region(self, sub_prompt_index, is_last_network, mask):
|
||||
pass
|
||||
|
||||
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
||||
pass
|
||||
261
networks/extract_lora_from_dylora.py
Normal file
261
networks/extract_lora_from_dylora.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
||||
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||
# Thanks to cloneofsimo
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
import numpy as np
|
||||
|
||||
|
||||
def load_state_dict(file_name):
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location="cpu")
|
||||
metadata = None
|
||||
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, metadata):
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
# Indexing functions
|
||||
|
||||
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0) / original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
s_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0) / s_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_ratio(S, target):
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv / target
|
||||
index = int(torch.sum(S > min_sv).item())
|
||||
index = max(1, min(index, len(S) - 1))
|
||||
|
||||
return index
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
|
||||
del U, S, Vh, weight
|
||||
return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
||||
del U, S, Vh, weight
|
||||
return param_dict
|
||||
|
||||
|
||||
def merge_conv(lora_down, lora_up, device):
|
||||
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
||||
out_size, out_rank, _, _ = lora_up.shape
|
||||
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
|
||||
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
def merge_linear(lora_down, lora_up, device):
|
||||
in_rank, in_size = lora_down.shape
|
||||
out_size, out_rank = lora_up.shape
|
||||
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
weight = lora_up @ lora_down
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
# Calculate new rank
|
||||
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
if dynamic_method == "sv_ratio":
|
||||
# Calculate new dim and alpha based off ratio
|
||||
new_rank = index_sv_ratio(S, dynamic_param) + 1
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
elif dynamic_method == "sv_cumulative":
|
||||
# Calculate new dim and alpha based off cumulative sum
|
||||
new_rank = index_sv_cumulative(S, dynamic_param) + 1
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
elif dynamic_method == "sv_fro":
|
||||
# Calculate new dim and alpha based off sqrt sum of squares
|
||||
new_rank = index_sv_fro(S, dynamic_param) + 1
|
||||
new_alpha = float(scale * new_rank)
|
||||
else:
|
||||
new_rank = rank
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||
new_rank = 1
|
||||
new_alpha = float(scale * new_rank)
|
||||
elif new_rank > rank: # cap max rank at rank
|
||||
new_rank = rank
|
||||
new_alpha = float(scale * new_rank)
|
||||
|
||||
# Calculate resize info
|
||||
s_sum = torch.sum(torch.abs(S))
|
||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||
|
||||
S_squared = S.pow(2)
|
||||
s_fro = torch.sqrt(torch.sum(S_squared))
|
||||
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
||||
fro_percent = float(s_red_fro / s_fro)
|
||||
|
||||
param_dict["new_rank"] = new_rank
|
||||
param_dict["new_alpha"] = new_alpha
|
||||
param_dict["sum_retained"] = (s_rank) / s_sum
|
||||
param_dict["fro_retained"] = fro_percent
|
||||
param_dict["max_ratio"] = S[0] / S[new_rank - 1]
|
||||
|
||||
return param_dict
|
||||
|
||||
|
||||
def split_lora_model(lora_sd, unit):
|
||||
max_rank = 0
|
||||
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
if "lora_down" in key:
|
||||
rank = value.size()[0]
|
||||
if rank > max_rank:
|
||||
max_rank = rank
|
||||
print(f"Max rank: {max_rank}")
|
||||
|
||||
rank = unit
|
||||
splitted_models = []
|
||||
while rank < max_rank:
|
||||
print(f"Splitting rank {rank}")
|
||||
new_sd = {}
|
||||
for key, value in lora_sd.items():
|
||||
if "lora_down" in key:
|
||||
new_sd[key] = value[:rank].contiguous()
|
||||
elif "lora_up" in key:
|
||||
new_sd[key] = value[:, :rank].contiguous()
|
||||
else:
|
||||
new_sd[key] = value # alpha and other parameters
|
||||
|
||||
splitted_models.append((new_sd, rank))
|
||||
rank += unit
|
||||
|
||||
return max_rank, splitted_models
|
||||
|
||||
|
||||
def split(args):
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model)
|
||||
|
||||
print("Splitting Model...")
|
||||
original_rank, splitted_models = split_lora_model(lora_sd, args.unit)
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
for state_dict, new_rank in splitted_models:
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
new_metadata = {}
|
||||
else:
|
||||
new_metadata = metadata.copy()
|
||||
|
||||
new_metadata["ss_training_comment"] = f"split from DyLoRA from {original_rank} to {new_rank}; {comment}"
|
||||
new_metadata["ss_network_dim"] = str(new_rank)
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
filename, ext = os.path.splitext(args.save_to)
|
||||
model_file_name = filename + f"-{new_rank:04d}{ext}"
|
||||
|
||||
print(f"saving model to: {model_file_name}")
|
||||
save_to_file(model_file_name, state_dict, new_metadata)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--unit", type=int, default=None, help="size of rank to split into / rankを分割するサイズ")
|
||||
parser.add_argument(
|
||||
"--save_to",
|
||||
type=str,
|
||||
default=None,
|
||||
help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
default=None,
|
||||
help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = setup_parser()
|
||||
|
||||
args = parser.parse_args()
|
||||
split(args)
|
||||
Reference in New Issue
Block a user