mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Update initialization, add lora_util, add tests
This commit is contained in:
@@ -18,7 +18,7 @@ from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.train_util import initialize_lora, initialize_pissa, initialize_urae
|
||||
from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -38,7 +38,7 @@ class LoRAModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
org_module: torch.nn.Linear,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
@@ -56,68 +56,32 @@ class LoRAModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
alpha = alpha.detach().float().item() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
self.split_dims = split_dims
|
||||
self.initialize = initialize
|
||||
|
||||
if split_dims is None:
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
kernel_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||
else:
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
|
||||
if initialize == "urae":
|
||||
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||
elif initialize == "pissa":
|
||||
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||
else:
|
||||
initialize_lora(self.lora_down, self.lora_up)
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
else:
|
||||
# conv2d not supported
|
||||
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
|
||||
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
|
||||
# print(f"split_dims: {split_dims}")
|
||||
self.lora_down = torch.nn.ModuleList(
|
||||
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
|
||||
)
|
||||
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
|
||||
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
|
||||
if initialize == "urae":
|
||||
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = lora_down.weight.data.detach().clone()
|
||||
elif initialize == "pissa":
|
||||
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = lora_down.weight.data.detach().clone()
|
||||
else:
|
||||
initialize_lora(lora_down, lora_up)
|
||||
|
||||
with torch.autocast(org_module.weight.device.type), torch.no_grad():
|
||||
self.initialize_weights(org_module)
|
||||
|
||||
# same as microsoft's
|
||||
self.multiplier = multiplier
|
||||
@@ -126,12 +90,44 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
|
||||
def initialize_weights(self, org_module: torch.nn.Module):
|
||||
if self.split_dims is None:
|
||||
if self.initialize == "urae":
|
||||
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||
elif self.initialize == "pissa":
|
||||
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||
else:
|
||||
initialize_lora(self.lora_down, self.lora_up)
|
||||
else:
|
||||
assert isinstance(self.lora_down, torch.nn.ModuleList)
|
||||
assert isinstance(self.lora_up, torch.nn.ModuleList)
|
||||
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
|
||||
if self.initialize == "urae":
|
||||
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = lora_down.weight.data.detach().clone()
|
||||
elif self.initialize == "pissa":
|
||||
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = lora_down.weight.data.detach().clone()
|
||||
else:
|
||||
initialize_lora(lora_down, lora_up)
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
org_forwarded = self.org_forward(x)
|
||||
|
||||
# module dropout
|
||||
@@ -175,10 +171,6 @@ class LoRAModule(torch.nn.Module):
|
||||
if self.rank_dropout is not None and self.training:
|
||||
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
|
||||
for i in range(len(lxs)):
|
||||
if len(lx.size()) == 3:
|
||||
masks[i] = masks[i].unsqueeze(1)
|
||||
elif len(lx.size()) == 4:
|
||||
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
|
||||
lxs[i] = lxs[i] * masks[i]
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
@@ -765,6 +757,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
# 毎回すべてのモジュールを作るのは無駄なので要検討
|
||||
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
|
||||
skipped_te = []
|
||||
text_encoders = text_encoders if isinstance(text_encoders, list) else [text_encoders]
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
index = i
|
||||
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
|
||||
@@ -1103,8 +1096,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# We need to create new low-rank matrices that represent this delta
|
||||
# One approach is to do SVD on delta_w
|
||||
U, S, V = torch.linalg.svd(delta_w, full_matrices=False)
|
||||
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# Take the top 2*r singular values (as suggested in the paper)
|
||||
rank = rank * 2
|
||||
@@ -1124,7 +1116,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_down_key = f"{lora.lora_name}.lora_down.weight"
|
||||
lora_up = state_dict[lora_up_key]
|
||||
lora_down = state_dict[lora_down_key]
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim)
|
||||
with torch.autocast("cuda"):
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim)
|
||||
state_dict[lora_up_key] = up.detach()
|
||||
state_dict[lora_down_key] = down.detach()
|
||||
progress.update(1)
|
||||
|
||||
Reference in New Issue
Block a user