mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix parameters are not freezed
This commit is contained in:
@@ -50,15 +50,17 @@ class DyLoRAModule(torch.nn.Module):
|
|||||||
kernel_size = org_module.kernel_size
|
kernel_size = org_module.kernel_size
|
||||||
self.stride = org_module.stride
|
self.stride = org_module.stride
|
||||||
self.padding = org_module.padding
|
self.padding = org_module.padding
|
||||||
self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size)))
|
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
|
||||||
self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1)))
|
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
|
||||||
else:
|
else:
|
||||||
self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim)))
|
self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
|
||||||
self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim)))
|
self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
|
||||||
|
|
||||||
# same as microsoft's
|
# same as microsoft's
|
||||||
torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
for lora in self.lora_A:
|
||||||
torch.nn.init.zeros_(self.lora_B)
|
torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
|
||||||
|
for lora in self.lora_B:
|
||||||
|
torch.nn.init.zeros_(lora)
|
||||||
|
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.org_module = org_module # remove in applying
|
self.org_module = org_module # remove in applying
|
||||||
@@ -76,38 +78,18 @@ class DyLoRAModule(torch.nn.Module):
|
|||||||
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
|
trainable_rank = trainable_rank - trainable_rank % self.unit # make sure the rank is a multiple of unit
|
||||||
|
|
||||||
# 一部のパラメータを固定して、残りのパラメータを学習する
|
# 一部のパラメータを固定して、残りのパラメータを学習する
|
||||||
|
for i in range(0, trainable_rank):
|
||||||
|
self.lora_A[i].requires_grad = False
|
||||||
|
self.lora_B[i].requires_grad = False
|
||||||
|
for i in range(trainable_rank, trainable_rank + self.unit):
|
||||||
|
self.lora_A[i].requires_grad = True
|
||||||
|
self.lora_B[i].requires_grad = True
|
||||||
|
for i in range(trainable_rank + self.unit, self.lora_dim):
|
||||||
|
self.lora_A[i].requires_grad = False
|
||||||
|
self.lora_B[i].requires_grad = False
|
||||||
|
|
||||||
# make lora_A
|
lora_A = torch.cat(tuple(self.lora_A), dim=0)
|
||||||
if trainable_rank > 0:
|
lora_B = torch.cat(tuple(self.lora_B), dim=1)
|
||||||
lora_A_nt1 = [self.lora_A[:trainable_rank].detach()]
|
|
||||||
else:
|
|
||||||
lora_A_nt1 = []
|
|
||||||
|
|
||||||
lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit]
|
|
||||||
|
|
||||||
if trainable_rank < self.lora_dim - self.unit:
|
|
||||||
lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()]
|
|
||||||
else:
|
|
||||||
lora_A_nt2 = []
|
|
||||||
|
|
||||||
lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0)
|
|
||||||
|
|
||||||
# make lora_B
|
|
||||||
if trainable_rank > 0:
|
|
||||||
lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()]
|
|
||||||
else:
|
|
||||||
lora_B_nt1 = []
|
|
||||||
|
|
||||||
lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit]
|
|
||||||
|
|
||||||
if trainable_rank < self.lora_dim - self.unit:
|
|
||||||
lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()]
|
|
||||||
else:
|
|
||||||
lora_B_nt2 = []
|
|
||||||
|
|
||||||
lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1)
|
|
||||||
|
|
||||||
# print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size())
|
|
||||||
|
|
||||||
# calculate with lora_A and lora_B
|
# calculate with lora_A and lora_B
|
||||||
if self.is_conv2d_3x3:
|
if self.is_conv2d_3x3:
|
||||||
@@ -116,13 +98,13 @@ class DyLoRAModule(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
ab = x
|
ab = x
|
||||||
if self.is_conv2d:
|
if self.is_conv2d:
|
||||||
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)
|
ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2) # (N, C, H, W) -> (N, H*W, C)
|
||||||
|
|
||||||
ab = torch.nn.functional.linear(ab, lora_A)
|
ab = torch.nn.functional.linear(ab, lora_A)
|
||||||
ab = torch.nn.functional.linear(ab, lora_B)
|
ab = torch.nn.functional.linear(ab, lora_B)
|
||||||
|
|
||||||
if self.is_conv2d:
|
if self.is_conv2d:
|
||||||
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])
|
ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:]) # (N, H*W, C) -> (N, C, H, W)
|
||||||
|
|
||||||
# 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな)
|
# 最後の項は、低rankをより大きくするためのスケーリング(じゃないかな)
|
||||||
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
|
result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
|
||||||
@@ -131,34 +113,52 @@ class DyLoRAModule(torch.nn.Module):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||||
# state dictを通常のLoRAと同じにする
|
# state dictを通常のLoRAと同じにする:
|
||||||
state_dict = super().state_dict(destination, prefix, keep_vars)
|
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
|
||||||
|
sd = super().state_dict(destination, prefix, keep_vars)
|
||||||
|
|
||||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_A")
|
lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
|
||||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||||
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
|
lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
|
||||||
state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight
|
|
||||||
|
|
||||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_B")
|
lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
|
||||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||||
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
|
lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
|
||||||
state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight
|
|
||||||
|
|
||||||
return state_dict
|
sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
|
||||||
|
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while True:
|
||||||
|
key_a = f"{self.lora_name}.lora_A.{i}"
|
||||||
|
key_b = f"{self.lora_name}.lora_B.{i}"
|
||||||
|
if key_a in sd:
|
||||||
|
sd.pop(key_a)
|
||||||
|
sd.pop(key_b)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
i += 1
|
||||||
|
return sd
|
||||||
|
|
||||||
def load_state_dict(self, state_dict, strict=True):
|
def load_state_dict(self, state_dict, strict=True):
|
||||||
# 通常のLoRAと同じstate dictを読み込めるようにする
|
# 通常のLoRAと同じstate dictを読み込めるようにする
|
||||||
state_dict = state_dict.copy()
|
state_dict = state_dict.copy()
|
||||||
|
|
||||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight")
|
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
|
||||||
|
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
|
||||||
|
|
||||||
|
if lora_A_weight is None or lora_B_weight is None:
|
||||||
|
if strict:
|
||||||
|
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||||
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
||||||
state_dict[self.lora_name + ".lora_A"] = lora_A_weight
|
|
||||||
|
|
||||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight")
|
|
||||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
|
||||||
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
||||||
state_dict[self.lora_name + ".lora_B"] = lora_B_weight
|
|
||||||
|
state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
|
||||||
|
state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})
|
||||||
|
|
||||||
super().load_state_dict(state_dict, strict=strict)
|
super().load_state_dict(state_dict, strict=strict)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user