Update initialization, add lora_util, add tests

rockerBOO
2025-03-25 18:22:07 -04:00
parent 0bad5ae9f1
commit 0ad3b3c2bd
5 changed files with 961 additions and 134 deletions


@@ -18,7 +18,7 @@ from torch import Tensor
from tqdm import tqdm
import re
from library.utils import setup_logging
from library.train_util import initialize_lora, initialize_pissa, initialize_urae
from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
setup_logging()
import logging
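The helpers now come from the new library.lora_util module, which this commit adds but whose contents are not shown in the diff. A minimal sketch of what the plain initialize_lora helper might look like, assuming it follows the "same as microsoft's" convention referenced later in this file; the name and call signature come from the call sites, everything else is an assumption:

import math
import torch

def initialize_lora(lora_down: torch.nn.Linear, lora_up: torch.nn.Linear) -> None:
    # Assumed sketch: standard LoRA init -- Kaiming-uniform lora_down and a
    # zeroed lora_up, so the adapter contributes nothing on the first forward
    # pass. The real lora_util implementation may differ.
    torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
    torch.nn.init.zeros_(lora_up.weight)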
@@ -38,7 +38,7 @@ class LoRAModule(torch.nn.Module):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
org_module: torch.nn.Linear,
multiplier=1.0,
lora_dim=4,
alpha=1,
@@ -56,68 +56,32 @@ class LoRAModule(torch.nn.Module):
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_dim = lora_dim
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().item() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # can be treated as a constant
self.split_dims = split_dims
self.initialize = initialize
if split_dims is None:
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
if initialize == "urae":
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif initialize == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
else:
initialize_lora(self.lora_down, self.lora_up)
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
else:
# conv2d not supported
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
# print(f"split_dims: {split_dims}")
self.lora_down = torch.nn.ModuleList(
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
)
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
if initialize == "urae":
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
elif initialize == "pissa":
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
else:
initialize_lora(lora_down, lora_up)
with torch.autocast(org_module.weight.device.type), torch.no_grad():
self.initialize_weights(org_module)
# same as microsoft's
self.multiplier = multiplier
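initialize_pissa and initialize_urae also live in the new lora_util module and are not part of this hunk. A rough sketch of a PiSSA-style initialization matching the call signature used above; all implementation details below are assumptions, not the module's actual code:

import torch

def initialize_pissa(org_module: torch.nn.Linear, lora_down: torch.nn.Linear,
                     lora_up: torch.nn.Linear, scale: float, rank: int) -> None:
    # Assumed sketch: put the top-`rank` principal components of the original
    # weight into the adapter and keep the residual in the base weight, so
    # base + scale * (up @ down) reproduces the original weight at step 0.
    weight = org_module.weight.data.detach().float()             # (out, in)
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
    up = (U * S.sqrt()) / scale                                   # (out, rank)
    down = S.sqrt().unsqueeze(1) * Vh                             # (rank, in)
    lora_up.weight.data = up.to(org_module.weight.dtype)
    lora_down.weight.data = down.to(org_module.weight.dtype)
    org_module.weight.data = (weight - scale * (up @ down)).to(org_module.weight.dtype)

Storing _org_lora_up / _org_lora_down right after this kind of init is what later allows the trained adapter to be converted back into a plain LoRA.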
@@ -126,12 +90,44 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def initialize_weights(self, org_module: torch.nn.Module):
if self.split_dims is None:
if self.initialize == "urae":
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif self.initialize == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
else:
initialize_lora(self.lora_down, self.lora_up)
else:
assert isinstance(self.lora_down, torch.nn.ModuleList)
assert isinstance(self.lora_up, torch.nn.ModuleList)
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
if self.initialize == "urae":
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
elif self.initialize == "pissa":
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
else:
initialize_lora(lora_down, lora_up)
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
def forward(self, x) -> torch.Tensor:
org_forwarded = self.org_forward(x)
# module dropout
@@ -175,10 +171,6 @@ class LoRAModule(torch.nn.Module):
if self.rank_dropout is not None and self.training:
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
for i in range(len(lxs)):
if len(lx.size()) == 3:
masks[i] = masks[i].unsqueeze(1)
elif len(lx.size()) == 4:
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
lxs[i] = lxs[i] * masks[i]
# scaling for rank dropout: treat as if the rank is changed
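The rescaling that follows this comment is not included in the hunk; a sketch of one common approach, assumed rather than taken from the diff:

def rank_dropout_scale(base_scale: float, rank_dropout: float) -> float:
    # Illustrative helper (assumed): after zeroing a fraction `rank_dropout`
    # of the rank channels, scale the surviving output up as if the rank had
    # shrunk, so its expected magnitude stays unchanged.
    return base_scale * (1.0 / (1.0 - rank_dropout))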
@@ -765,6 +757,7 @@ class LoRANetwork(torch.nn.Module):
# creating all the modules every time is wasteful; needs reconsideration
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
skipped_te = []
text_encoders = text_encoders if isinstance(text_encoders, list) else [text_encoders]
for i, text_encoder in enumerate(text_encoders):
index = i
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
@@ -1103,8 +1096,7 @@ class LoRANetwork(torch.nn.Module):
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
# We need to create new low-rank matrices that represent this delta
# One approach is to do SVD on delta_w
U, S, V = torch.linalg.svd(delta_w, full_matrices=False)
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
# Take the top 2*r singular values (as suggested in the paper)
rank = rank * 2
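How the truncated factors are turned back into lora_up / lora_down matrices is not visible in this hunk; a small illustrative sketch of that step (assumed). Note that torch.linalg.svd returns Vh (V transposed), so delta_w ≈ U @ diag(S) @ Vh:

import torch

def factor_delta(delta_w: torch.Tensor, rank: int):
    # Illustrative sketch: best rank-`rank` approximation of the weight delta,
    # with the singular values split evenly across the two factors.
    U, S, Vh = torch.linalg.svd(delta_w.float(), full_matrices=False)
    U, S, Vh = U[:, :rank], S[:rank], Vh[:rank, :]
    up = U * S.sqrt()                    # (out_dim, rank)
    down = S.sqrt().unsqueeze(1) * Vh    # (rank, in_dim)
    return up, down                      # up @ down ≈ delta_w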
@@ -1124,7 +1116,8 @@ class LoRANetwork(torch.nn.Module):
lora_down_key = f"{lora.lora_name}.lora_down.weight"
lora_up = state_dict[lora_up_key]
lora_down = state_dict[lora_down_key]
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim)
with torch.autocast("cuda"):
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim)
state_dict[lora_up_key] = up.detach()
state_dict[lora_down_key] = down.detach()
progress.update(1)
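A hypothetical sanity check for the conversion above (not part of this commit; the function name and arguments are illustrative): the re-factored standard LoRA should reproduce the delta that was learned on top of the PiSSA initialization, up to the rank truncation of the SVD.

import torch

def check_pissa_conversion(trained_up, trained_down, orig_up, orig_down, up, down, atol=1e-2):
    # Delta learned on top of the PiSSA init vs. delta carried by the
    # converted standard-LoRA factors.
    delta_ref = trained_up.float() @ trained_down.float() - orig_up.float() @ orig_down.float()
    delta_new = up.float() @ down.float()
    return torch.allclose(delta_ref, delta_new, atol=atol)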