mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add DyLoRA (experimental)
This commit is contained in:
448
networks/dylora.py
Normal file
448
networks/dylora.py
Normal file
@@ -0,0 +1,448 @@
|
|||||||
|
# some codes are copied from:
|
||||||
|
# https://github.com/huawei-noah/KD-NLP/blob/main/DyLoRA/
|
||||||
|
|
||||||
|
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
|
||||||
|
# Changes made to the original code:
|
||||||
|
# 2022.08.20 - Integrate the DyLoRA layer for the LoRA Linear layer
|
||||||
|
# ------------------------------------------------------------------------------------------
|
||||||
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||||
|
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||||
|
# ------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from typing import List, Tuple, Union
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class DyLoRAModule(torch.nn.Module):
    """DyLoRA layer (https://arxiv.org/abs/2210.07558).

    Replaces the forward method of the original Linear/Conv2d, instead of
    replacing the original module itself.  On every forward pass a random
    trainable rank (a multiple of ``unit``) is picked; the slices of
    lora_A/lora_B outside that unit are detached so that, after training,
    any rank up to ``lora_dim`` can be used at inference time.
    """

    # NOTE: support dropout in future
    def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1, unit=1):
        """
        Args:
            lora_name: unique name used for state-dict keys (dots already replaced by "_").
            org_module: the Linear or Conv2d whose forward will be wrapped.
            multiplier: output scaling factor (stored; applied by the surrounding framework).
            lora_dim: maximum LoRA rank.
            alpha: LoRA alpha; 0/None falls back to lora_dim (scale 1.0).
            unit: granularity of the dynamic rank; lora_dim must be a multiple of it.
        """
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim
        self.unit = unit
        assert self.lora_dim % self.unit == 0, "rank must be a multiple of unit"

        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # treated as a constant (saved, not trained)

        self.is_conv2d = org_module.__class__.__name__ == "Conv2d"
        self.is_conv2d_3x3 = self.is_conv2d and org_module.kernel_size == (3, 3)

        # is_conv2d_3x3 already implies is_conv2d, so a single check suffices
        if self.is_conv2d_3x3:
            kernel_size = org_module.kernel_size
            self.stride = org_module.stride
            self.padding = org_module.padding
            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size)))
            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1)))
        else:
            # Linear and 1x1 Conv2d both use 2-D factors (conv weights are reshaped on save/load)
            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim)))
            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim)))

        # same as microsoft's
        torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        torch.nn.init.zeros_(self.lora_B)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying

    def apply_to(self):
        """Hook this module's forward into the original module and drop the reference to it."""
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        result = self.org_forward(x)

        # specify the dynamic rank
        trainable_rank = random.randint(0, self.lora_dim - 1)
        trainable_rank = trainable_rank - trainable_rank % self.unit  # make sure the rank is a multiple of unit

        # freeze (detach) the slices outside the chosen unit; only that unit is trained this step

        # make lora_A
        if trainable_rank > 0:
            lora_A_nt1 = [self.lora_A[:trainable_rank].detach()]
        else:
            lora_A_nt1 = []

        lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit]

        if trainable_rank < self.lora_dim - self.unit:
            lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()]
        else:
            lora_A_nt2 = []

        lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0)

        # make lora_B (same pattern, but the rank axis is dim 1)
        if trainable_rank > 0:
            lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()]
        else:
            lora_B_nt1 = []

        lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit]

        if trainable_rank < self.lora_dim - self.unit:
            lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()]
        else:
            lora_B_nt2 = []

        lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1)

        # print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size())

        # calculate with lora_A and lora_B
        if self.is_conv2d_3x3:
            ab = torch.nn.functional.conv2d(x, lora_A, stride=self.stride, padding=self.padding)
            ab = torch.nn.functional.conv2d(ab, lora_B)
        else:
            ab = x
            if self.is_conv2d:
                # flatten spatial dims so a 1x1 conv becomes a plain linear
                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)

            ab = torch.nn.functional.linear(ab, lora_A)
            ab = torch.nn.functional.linear(ab, lora_B)

            if self.is_conv2d:
                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])

        # the sqrt factor upweights lower ranks (presumably to balance their smaller capacity)
        result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))

        # NOTE it may be faster to add to the weight first and then call linear/conv2d once
        return result

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """Export keys compatible with a normal LoRA (lora_down/lora_up instead of lora_A/lora_B).

        NOTE(review): the pop below expects prefix == lora_name + "." — i.e. this
        module being a child registered under lora_name; calling state_dict()
        directly on a standalone instance would KeyError.
        """
        # use keyword arguments: positional args to Module.state_dict are deprecated in torch
        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        lora_A_weight = state_dict.pop(self.lora_name + ".lora_A")
        if self.is_conv2d and not self.is_conv2d_3x3:
            # 1x1 conv is stored as 2-D internally; add back the spatial dims for compatibility
            lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
        state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight

        lora_B_weight = state_dict.pop(self.lora_name + ".lora_B")
        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
        state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight

        return state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Accept a normal LoRA state dict by mapping lora_down/lora_up back to lora_A/lora_B."""
        state_dict = state_dict.copy()

        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight")
        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
        state_dict[self.lora_name + ".lora_A"] = lora_A_weight

        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight")
        if self.is_conv2d and not self.is_conv2d_3x3:
            lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
        state_dict[self.lora_name + ".lora_B"] = lora_B_weight

        super().load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
    """Build a DyLoRANetwork from command-line style options.

    ``vae`` is unused but kept for interface compatibility with the other
    network implementations.  Recognized kwargs: conv_dim, conv_alpha, unit.
    """
    # fall back to defaults when not specified
    network_dim = 4 if network_dim is None else network_dim
    network_alpha = 1.0 if network_alpha is None else network_alpha

    # extract dim/alpha for conv2d, and the DyLoRA unit size
    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    unit = kwargs.get("unit", None)

    if conv_dim is not None:
        conv_dim = int(conv_dim)
        assert conv_dim == network_dim, "conv_dim must be same as network_dim"
        conv_alpha = 1.0 if conv_alpha is None else float(conv_alpha)

    unit = 1 if unit is None else int(unit)

    return DyLoRANetwork(
        text_encoder,
        unet,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        apply_to_conv=conv_dim is not None,
        unit=unit,
        varbose=True,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||||
|
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
    """Create a DyLoRANetwork for inference from a weights file or state dict.

    Weights are NOT loaded here (they may be merged later instead).
    Returns (network, weights_sd).
    """
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # get dim/alpha mapping per module from the stored weights
    modules_dim = {}
    modules_alpha = {}
    for key, value in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.split(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = value
        elif "lora_down" in key:
            dim = value.size()[0]
            modules_dim[lora_name] = dim
            # print(lora_name, value.size(), dim)

    # support old LoRA without alpha: default alpha to the module's dim
    for key in modules_dim.keys():
        if key not in modules_alpha:
            # BUGFIX: was `modules_alpha = modules_dim[key]`, which clobbered the
            # whole dict with an int instead of filling in the missing entry
            modules_alpha[key] = modules_dim[key]

    module_class = DyLoRAModule

    network = DyLoRANetwork(
        text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
    )
    return network, weights_sd
|
||||||
|
|
||||||
|
|
||||||
|
class DyLoRANetwork(torch.nn.Module):
    """Network of DyLoRAModules applied to a text encoder and a U-Net.

    Mirrors the interface of the regular LoRA network so training/inference
    scripts can use either implementation interchangeably.
    """

    UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"]
    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"

    def __init__(
        self,
        text_encoder,
        unet,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        apply_to_conv=False,
        modules_dim=None,
        modules_alpha=None,
        unit=1,
        module_class=DyLoRAModule,
        varbose=False,
    ) -> None:
        """
        Args:
            modules_dim/modules_alpha: per-module dim/alpha maps (when creating from weights);
                when None, lora_dim/alpha are used for all targeted modules.
            apply_to_conv: also target 3x3 Conv2d layers in the U-Net.
            varbose: (sic) kept for interface compatibility; unused here.
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.apply_to_conv = apply_to_conv

        if modules_dim is not None:
            print("create LoRA network from weights")
        else:
            print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
            if self.apply_to_conv:
                print("apply LoRA to Conv2d with kernel size (3,3).")

        # create module instances
        def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
            prefix = DyLoRANetwork.LORA_PREFIX_UNET if is_unet else DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER
            loras = []
            for name, module in root_module.named_modules():
                if module.__class__.__name__ in target_replace_modules:
                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"
                        is_conv2d = child_module.__class__.__name__ == "Conv2d"
                        is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)

                        if is_linear or is_conv2d:
                            lora_name = prefix + "." + name + "." + child_name
                            lora_name = lora_name.replace(".", "_")

                            dim = None
                            alpha = None
                            if modules_dim is not None:
                                # creating from weights: only build modules present there
                                if lora_name in modules_dim:
                                    dim = modules_dim[lora_name]
                                    alpha = modules_alpha[lora_name]
                            else:
                                # 3x3 conv is only targeted when apply_to_conv is set
                                if is_linear or is_conv2d_1x1 or apply_to_conv:
                                    dim = self.lora_dim
                                    alpha = self.alpha

                            if dim is None or dim == 0:
                                continue

                            # dropout and fan_in_fan_out is default
                            lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
                            loras.append(lora)
            return loras

        self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
        print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")

        # extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
        target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
        if modules_dim is not None or self.apply_to_conv:
            # BUGFIX: was `target_modules += ...`, which mutated the CLASS-level list
            # in place and leaked conv targets into every later DyLoRANetwork instance
            target_modules = target_modules + DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3

        self.unet_loras = create_modules(True, unet, target_modules)
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

    def set_multiplier(self, multiplier):
        """Propagate a new output multiplier to every child LoRA module."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def load_weights(self, file):
        """Load network weights from a .safetensors or torch checkpoint file (non-strict)."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        info = self.load_state_dict(weights_sd, False)
        return info

    def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
        """Hook every LoRA module into its target module and register it as a submodule.

        Disabled parts (apply_*=False) have their lora lists cleared.
        """
        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    # kept for reference: merging is not implemented for DyLoRA yet
    """
    def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
        apply_text_encoder = apply_unet = False
        for key in weights_sd.keys():
            if key.startswith(DyLoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                apply_text_encoder = True
            elif key.startswith(DyLoRANetwork.LORA_PREFIX_UNET):
                apply_unet = True

        if apply_text_encoder:
            print("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            print("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        print(f"weights are merged")
    """

    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        """Build optimizer param groups, with optional separate LRs for TE and U-Net."""
        self.requires_grad_(True)
        all_params = []

        def enumerate_params(loras):
            params = []
            for lora in loras:
                params.extend(lora.parameters())
            return params

        if self.text_encoder_loras:
            param_data = {"params": enumerate_params(self.text_encoder_loras)}
            if text_encoder_lr is not None:
                param_data["lr"] = text_encoder_lr
            all_params.append(param_data)

        if self.unet_loras:
            param_data = {"params": enumerate_params(self.unet_loras)}
            if unet_lr is not None:
                param_data["lr"] = unet_lr
            all_params.append(param_data)

        return all_params

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        """Enable gradients on all LoRA parameters before training."""
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        """Switch LoRA modules to training mode at epoch start."""
        self.train()

    def get_trainable_params(self):
        """Return all trainable parameters of the network."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save the network's state dict to .safetensors or a torch checkpoint.

        Casts tensors to ``dtype`` (if given) and, for safetensors, embeds
        model hashes into the metadata.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    # mask is a tensor with values from 0 to 1
    def set_region(self, sub_prompt_index, is_last_network, mask):
        # regional prompting is not supported for DyLoRA
        pass

    def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
        # regional prompting is not supported for DyLoRA
        pass
|
||||||
261
networks/extract_lora_from_dylora.py
Normal file
261
networks/extract_lora_from_dylora.py
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
# Convert LoRA to different rank approximation (should only be used to go to lower rank)
|
||||||
|
# This code is based off the extract_lora_from_models.py file which is based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
|
||||||
|
# Thanks to cloneofsimo
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
|
from tqdm import tqdm
|
||||||
|
from library import train_util, model_util
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def load_state_dict(file_name):
    """Load a LoRA state dict plus its metadata (safetensors only; None otherwise)."""
    if not model_util.is_safetensors(file_name):
        return torch.load(file_name, map_location="cpu"), None

    sd = load_file(file_name)
    with safe_open(file_name, framework="pt") as f:
        metadata = f.metadata()
    return sd, metadata
|
||||||
|
|
||||||
|
|
||||||
|
def save_to_file(file_name, model, metadata):
    """Write a state dict to disk; metadata is only persisted for safetensors files."""
    if not model_util.is_safetensors(file_name):
        torch.save(model, file_name)
        return
    save_file(model, file_name, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
# Indexing functions
|
||||||
|
|
||||||
|
|
||||||
|
def index_sv_cumulative(S, target):
    """Smallest index whose cumulative share of sum(S) reaches ``target``, clamped to [1, len(S)-1]."""
    total = float(torch.sum(S))
    shares = torch.cumsum(S, dim=0) / total
    idx = int(torch.searchsorted(shares, target)) + 1
    return max(1, min(idx, len(S) - 1))
|
||||||
|
|
||||||
|
|
||||||
|
def index_sv_fro(S, target):
    """Smallest index whose cumulative squared share of ||S||_F^2 reaches target**2, clamped to [1, len(S)-1]."""
    squares = S.pow(2)
    shares = torch.cumsum(squares, dim=0) / float(torch.sum(squares))
    idx = int(torch.searchsorted(shares, target**2)) + 1
    return max(1, min(idx, len(S) - 1))
|
||||||
|
|
||||||
|
|
||||||
|
def index_sv_ratio(S, target):
    """Count of singular values larger than S[0]/target, clamped to [1, len(S)-1]."""
    threshold = S[0] / target
    idx = int(torch.sum(S > threshold).item())
    return max(1, min(idx, len(S) - 1))
|
||||||
|
|
||||||
|
|
||||||
|
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||||
|
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """SVD-extract a conv LoRA pair (full-kernel lora_down, 1x1 lora_up) from a conv weight.

    Returns the rank_resize() dict augmented with "lora_down" and "lora_up" (on CPU).
    """
    out_size, in_size, kernel_size, _ = weight.size()
    U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))

    # decide the rank from the singular values
    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    new_rank = param_dict["new_rank"]

    # truncate both factors and fold the singular values into U
    U = U[:, :new_rank] @ torch.diag(S[:new_rank])
    Vh = Vh[:new_rank, :]

    param_dict["lora_down"] = Vh.reshape(new_rank, in_size, kernel_size, kernel_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, new_rank, 1, 1).cpu()
    del U, S, Vh, weight
    return param_dict
|
||||||
|
|
||||||
|
|
||||||
|
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
    """SVD-extract a linear LoRA pair from a 2-D weight.

    Returns the rank_resize() dict augmented with "lora_down" and "lora_up" (on CPU).
    """
    out_size, in_size = weight.size()

    U, S, Vh = torch.linalg.svd(weight.to(device))

    # decide the rank from the singular values
    param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
    new_rank = param_dict["new_rank"]

    # truncate both factors and fold the singular values into U
    U = U[:, :new_rank] @ torch.diag(S[:new_rank])
    Vh = Vh[:new_rank, :]

    param_dict["lora_down"] = Vh.reshape(new_rank, in_size).cpu()
    param_dict["lora_up"] = U.reshape(out_size, new_rank).cpu()
    del U, S, Vh, weight
    return param_dict
|
||||||
|
|
||||||
|
|
||||||
|
def merge_conv(lora_down, lora_up, device):
    """Multiply a conv LoRA pair back into a full conv weight of shape (out, in, k, k)."""
    in_rank, in_size, kernel_size, k_ = lora_down.shape
    out_size, out_rank, _, _ = lora_up.shape
    assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"

    down = lora_down.to(device)
    up = lora_up.to(device)

    merged = up.reshape(out_size, -1) @ down.reshape(in_rank, -1)
    weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
    del up, down
    return weight
|
||||||
|
|
||||||
|
|
||||||
|
def merge_linear(lora_down, lora_up, device):
    """Multiply a linear LoRA pair back into a full weight of shape (out, in)."""
    in_rank, in_size = lora_down.shape
    out_size, out_rank = lora_up.shape
    assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"

    down = lora_down.to(device)
    up = lora_up.to(device)

    weight = up @ down
    del up, down
    return weight
|
||||||
|
|
||||||
|
|
||||||
|
# Calculate new rank
|
||||||
|
|
||||||
|
|
||||||
|
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
    """Pick a new rank/alpha from singular values ``S`` and report retention statistics.

    dynamic_method is one of "sv_ratio", "sv_cumulative", "sv_fro" or falsy
    (keep the requested rank).  The result is capped at ``rank`` and forced to
    1 for an (effectively) zero matrix.
    NOTE(review): relies on a module-level MIN_SV threshold defined elsewhere in this file.
    """
    # choose the candidate rank according to the dynamic method
    if dynamic_method == "sv_ratio":
        new_rank = index_sv_ratio(S, dynamic_param) + 1
    elif dynamic_method == "sv_cumulative":
        new_rank = index_sv_cumulative(S, dynamic_param) + 1
    elif dynamic_method == "sv_fro":
        new_rank = index_sv_fro(S, dynamic_param) + 1
    else:
        new_rank = rank

    if S[0] <= MIN_SV:  # zero matrix: fall back to rank 1
        new_rank = 1
    elif new_rank > rank:  # never exceed the requested rank
        new_rank = rank
    new_alpha = float(scale * new_rank)

    # retention statistics for logging
    abs_S = torch.abs(S)
    squares = S.pow(2)
    fro_percent = float(torch.sqrt(torch.sum(squares[:new_rank])) / torch.sqrt(torch.sum(squares)))

    return {
        "new_rank": new_rank,
        "new_alpha": new_alpha,
        "sum_retained": torch.sum(abs_S[:new_rank]) / torch.sum(abs_S),
        "fro_retained": fro_percent,
        "max_ratio": S[0] / S[new_rank - 1],
    }
|
||||||
|
|
||||||
|
|
||||||
|
def split_lora_model(lora_sd, unit):
    """Split a DyLoRA state dict into fixed-rank sub-LoRAs of rank unit, 2*unit, ... < max_rank.

    Returns (max_rank, [(state_dict, rank), ...]).
    """
    # Extract loaded lora dim and alpha
    max_rank = 0
    for key, value in lora_sd.items():
        if "lora_down" in key:
            max_rank = max(max_rank, value.size()[0])
    print(f"Max rank: {max_rank}")

    splitted_models = []
    for rank in range(unit, max_rank, unit):
        print(f"Splitting rank {rank}")

        def _truncate(key, value):
            if "lora_down" in key:
                return value[:rank].contiguous()
            if "lora_up" in key:
                return value[:, :rank].contiguous()
            return value  # alpha and other parameters

        new_sd = {key: _truncate(key, value) for key, value in lora_sd.items()}
        splitted_models.append((new_sd, rank))

    return max_rank, splitted_models
|
||||||
|
|
||||||
|
|
||||||
|
def split(args):
    """Split a DyLoRA model into multiple fixed-rank LoRA files, one per multiple of args.unit."""
    print("loading Model...")
    lora_sd, metadata = load_state_dict(args.model)

    print("Splitting Model...")
    original_rank, splitted_models = split_lora_model(lora_sd, args.unit)

    # BUGFIX: metadata is None for non-safetensors input; the unguarded
    # metadata.get(...) crashed here
    comment = metadata.get("ss_training_comment", "") if metadata is not None else ""
    for state_dict, new_rank in splitted_models:
        # update metadata
        if metadata is None:
            new_metadata = {}
        else:
            new_metadata = metadata.copy()

        new_metadata["ss_training_comment"] = f"split from DyLoRA from {original_rank} to {new_rank}; {comment}"
        new_metadata["ss_network_dim"] = str(new_rank)

        # BUGFIX: hash the split model with ITS metadata and store the hashes there;
        # the original wrote into the source `metadata`, crashing when it was None
        model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, new_metadata)
        new_metadata["sshs_model_hash"] = model_hash
        new_metadata["sshs_legacy_hash"] = legacy_hash

        filename, ext = os.path.splitext(args.save_to)
        model_file_name = filename + f"-{new_rank:04d}{ext}"

        print(f"saving model to: {model_file_name}")
        save_to_file(model_file_name, state_dict, new_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the DyLoRA split script."""
    parser = argparse.ArgumentParser()

    parser.add_argument("--unit", type=int, default=None,
                        help="size of rank to split into / rankを分割するサイズ")
    parser.add_argument("--save_to", type=str, default=None,
                        help="destination base file name: ckpt or safetensors file / 保存先のファイル名のbase、ckptまたはsafetensors")
    parser.add_argument("--model", type=str, default=None,
                        help="DyLoRA model to resize at to new rank: ckpt or safetensors file / 読み込むDyLoRAモデル、ckptまたはsafetensors")

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # parse CLI args and run the split
    split(setup_parser().parse_args())
|
||||||
@@ -221,7 +221,9 @@ def train(args):
|
|||||||
try:
|
try:
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
print("Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)")
|
print(
|
||||||
|
"Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)"
|
||||||
|
)
|
||||||
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)
|
||||||
|
|
||||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params)
|
||||||
@@ -541,6 +543,12 @@ def train(args):
|
|||||||
loss_list = []
|
loss_list = []
|
||||||
loss_total = 0.0
|
loss_total = 0.0
|
||||||
del train_dataset_group
|
del train_dataset_group
|
||||||
|
|
||||||
|
# if hasattr(network, "on_step_start"):
|
||||||
|
# on_step_start = network.on_step_start
|
||||||
|
# else:
|
||||||
|
# on_step_start = lambda *args, **kwargs: None
|
||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
@@ -553,6 +561,8 @@ def train(args):
|
|||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(network):
|
||||||
|
# on_step_start(text_encoder, unet)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
@@ -565,7 +575,8 @@ def train(args):
|
|||||||
with torch.set_grad_enabled(train_text_encoder):
|
with torch.set_grad_enabled(train_text_encoder):
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
|
encoder_hidden_states = get_weighted_text_embeddings(
|
||||||
|
tokenizer,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
batch["captions"],
|
batch["captions"],
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
|
|||||||
Reference in New Issue
Block a user