mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix parameters are not freezed
This commit is contained in:
@@ -50,15 +50,17 @@ class DyLoRAModule(torch.nn.Module):
|
||||
kernel_size = org_module.kernel_size
|
||||
self.stride = org_module.stride
|
||||
self.padding = org_module.padding
|
||||
self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size)))
|
||||
self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1)))
|
||||
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
|
||||
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
|
||||
else:
|
||||
self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim)))
|
||||
self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim)))
|
||||
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
|
||||
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
|
||||
|
||||
# same as microsoft's
|
||||
torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
torch.nn.init.zeros_(self.lora_B)
|
||||
for lora in self.lora_A:
|
||||
torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
|
||||
for lora in self.lora_B:
|
||||
torch.nn.init.zeros_(lora)
|
||||
|
||||
self.multiplier = multiplier
|
||||
self.org_module = org_module # remove in applying
|
||||
@@ -76,38 +78,18 @@ class DyLoRAModule(torch.nn.Module):
|
||||
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
|
||||
|
||||
# 一部のパラメータを固定して、残りのパラメータを学習する
|
||||
for i in range(0, trainable_rank):
|
||||
self.lora_A[i].requires_grad = False
|
||||
self.lora_B[i].requires_grad = False
|
||||
for i in range(trainable_rank, trainable_rank + self.unit):
|
||||
self.lora_A[i].requires_grad = True
|
||||
self.lora_B[i].requires_grad = True
|
||||
for i in range(trainable_rank + self.unit, self.lora_dim):
|
||||
self.lora_A[i].requires_grad = False
|
||||
self.lora_B[i].requires_grad = False
|
||||
|
||||
# make lora_A
|
||||
if trainable_rank > 0:
|
||||
lora_A_nt1 = [self.lora_A[:trainable_rank].detach()]
|
||||
else:
|
||||
lora_A_nt1 = []
|
||||
|
||||
lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit]
|
||||
|
||||
if trainable_rank < self.lora_dim - self.unit:
|
||||
lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()]
|
||||
else:
|
||||
lora_A_nt2 = []
|
||||
|
||||
lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0)
|
||||
|
||||
# make lora_B
|
||||
if trainable_rank > 0:
|
||||
lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()]
|
||||
else:
|
||||
lora_B_nt1 = []
|
||||
|
||||
lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit]
|
||||
|
||||
if trainable_rank < self.lora_dim - self.unit:
|
||||
lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()]
|
||||
else:
|
||||
lora_B_nt2 = []
|
||||
|
||||
lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1)
|
||||
|
||||
# print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size())
|
||||
lora_A = torch.cat(tuple(self.lora_A), dim=0)
|
||||
lora_B = torch.cat(tuple(self.lora_B), dim=1)
|
||||
|
||||
# calculate with lora_A and lora_B
|
||||
if self.is_conv2d_3x3:
|
||||
@@ -116,13 +98,13 @@ class DyLoRAModule(torch.nn.Module):
|
||||
else:
|
||||
ab = x
|
||||
if self.is_conv2d:
|
||||
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)
|
||||
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
|
||||
|
||||
ab = torch.nn.functional.linear(ab, lora_A)
|
||||
ab = torch.nn.functional.linear(ab, lora_B)
|
||||
|
||||
if self.is_conv2d:
|
||||
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])
|
||||
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
|
||||
|
||||
# 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな)
|
||||
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
|
||||
@@ -131,34 +113,52 @@ class DyLoRAModule(torch.nn.Module):
|
||||
return result
|
||||
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
# state dictを通常のLoRAと同じにする
|
||||
state_dict = super().state_dict(destination, prefix, keep_vars)
|
||||
# state dictを通常のLoRAと同じにする:
|
||||
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
|
||||
sd = super().state_dict(destination, prefix, keep_vars)
|
||||
|
||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_A")
|
||||
lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
|
||||
state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight
|
||||
|
||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_B")
|
||||
lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
|
||||
state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight
|
||||
|
||||
return state_dict
|
||||
sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
|
||||
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
key_a = f"{self.lora_name}.lora_A.{i}"
|
||||
key_b = f"{self.lora_name}.lora_B.{i}"
|
||||
if key_a in sd:
|
||||
sd.pop(key_a)
|
||||
sd.pop(key_b)
|
||||
else:
|
||||
break
|
||||
i += 1
|
||||
return sd
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
# 通常のLoRAと同じstate dictを読み込めるようにする
|
||||
state_dict = state_dict.copy()
|
||||
|
||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight")
|
||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
|
||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
|
||||
|
||||
if lora_A_weight is None or lora_B_weight is None:
|
||||
if strict:
|
||||
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
|
||||
else:
|
||||
return
|
||||
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
||||
state_dict[self.lora_name + ".lora_A"] = lora_A_weight
|
||||
|
||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight")
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
||||
state_dict[self.lora_name + ".lora_B"] = lora_B_weight
|
||||
|
||||
state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
|
||||
state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})
|
||||
|
||||
super().load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user