fix: parameters are not frozen

Kohya S
2023-04-13 21:14:20 +09:00
parent a097c42579
commit 9ff32fd4c0

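Why the switch to nn.ParameterList: requires_grad is a flag on a whole leaf tensor, so when all ranks live in one nn.Parameter there is no way to freeze a subset of rows, and the detach-based slicing in the old forward still left the entire parameter trainable from the optimizer's point of view (hence the commit message). With one leaf tensor per rank, the flag can be set rank by rank. A minimal sketch (generic PyTorch behavior, not code from this commit) of the difference:

    import torch
    import torch.nn as nn

    # single Parameter: the flag belongs to the whole tensor, and a row is a
    # non-leaf view, so per-rank freezing is not possible
    w = nn.Parameter(torch.zeros(4, 8))
    try:
        w[0].requires_grad = False
    except RuntimeError as e:
        print(e)  # you can only change requires_grad flags of leaf variables

    # ParameterList: each rank is its own leaf tensor and can be toggled alone
    ws = nn.ParameterList([nn.Parameter(torch.zeros(1, 8)) for _ in range(4)])
    ws[0].requires_grad = False
    print([p.requires_grad for p in ws])  # [False, True, True, True]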

@@ -50,15 +50,17 @@ class DyLoRAModule(torch.nn.Module):
             kernel_size = org_module.kernel_size
             self.stride = org_module.stride
             self.padding = org_module.padding
-            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size)))
-            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1)))
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
         else:
-            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim)))
-            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim)))
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])

         # same as microsoft's
-        torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_B)
+        for lora in self.lora_A:
+            torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
+        for lora in self.lora_B:
+            torch.nn.init.zeros_(lora)

         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
@@ -76,38 +78,18 @@ class DyLoRAModule(torch.nn.Module):
         trainable_rank = trainable_rank - trainable_rank % self.unit  # make sure the rank is a multiple of unit

         # freeze some of the parameters and train the rest
+        for i in range(0, trainable_rank):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+        for i in range(trainable_rank, trainable_rank + self.unit):
+            self.lora_A[i].requires_grad = True
+            self.lora_B[i].requires_grad = True
+        for i in range(trainable_rank + self.unit, self.lora_dim):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+
-        # make lora_A
-        if trainable_rank > 0:
-            lora_A_nt1 = [self.lora_A[:trainable_rank].detach()]
-        else:
-            lora_A_nt1 = []
-        lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit]
-        if trainable_rank < self.lora_dim - self.unit:
-            lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()]
-        else:
-            lora_A_nt2 = []
-        lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0)
-
-        # make lora_B
-        if trainable_rank > 0:
-            lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()]
-        else:
-            lora_B_nt1 = []
-        lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit]
-        if trainable_rank < self.lora_dim - self.unit:
-            lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()]
-        else:
-            lora_B_nt2 = []
-        lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1)
-
-        # print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size())
+        lora_A = torch.cat(tuple(self.lora_A), dim=0)
+        lora_B = torch.cat(tuple(self.lora_B), dim=1)

         # calculate with lora_A and lora_B
         if self.is_conv2d_3x3:
@@ -116,13 +98,13 @@ class DyLoRAModule(torch.nn.Module):
         else:
             ab = x
             if self.is_conv2d:
-                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)
+                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)

             ab = torch.nn.functional.linear(ab, lora_A)
             ab = torch.nn.functional.linear(ab, lora_B)

             if self.is_conv2d:
-                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])
+                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])  # (N, H*W, C) -> (N, C, H, W)

         # the last factor is probably a scaling to give smaller ranks a larger weight
         result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
@@ -131,34 +113,52 @@ class DyLoRAModule(torch.nn.Module):
         return result

     def state_dict(self, destination=None, prefix="", keep_vars=False):
-        # make the state dict the same as a normal LoRA's
-        state_dict = super().state_dict(destination, prefix, keep_vars)
+        # make the state dict the same as a normal LoRA's:
+        # nn.ParameterList produces names like `.lora_A.0`, so cat the pieces and swap the keys, as in forward
+        sd = super().state_dict(destination, prefix, keep_vars)

-        lora_A_weight = state_dict.pop(self.lora_name + ".lora_A")
+        lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
-        state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight

-        lora_B_weight = state_dict.pop(self.lora_name + ".lora_B")
+        lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
-        state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight

-        return state_dict
+        sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
+        sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
+
+        i = 0
+        while True:
+            key_a = f"{self.lora_name}.lora_A.{i}"
+            key_b = f"{self.lora_name}.lora_B.{i}"
+            if key_a in sd:
+                sd.pop(key_a)
+                sd.pop(key_b)
+            else:
+                break
+            i += 1
+
+        return sd

     def load_state_dict(self, state_dict, strict=True):
         # allow loading the same state dict as a normal LoRA's
         state_dict = state_dict.copy()

-        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight")
+        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
+        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
+        if lora_A_weight is None or lora_B_weight is None:
+            if strict:
+                raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
+            else:
+                return
+
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
-        state_dict[self.lora_name + ".lora_A"] = lora_A_weight
-
-        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight")
-        if self.is_conv2d and not self.is_conv2d_3x3:
             lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
-        state_dict[self.lora_name + ".lora_B"] = lora_B_weight
+
+        state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
+        state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})

         super().load_state_dict(state_dict, strict=strict)
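Taken together, the state_dict/load_state_dict overrides keep DyLoRA checkpoints interchangeable with plain LoRA ones: the per-rank pieces are concatenated into the usual lora_down.weight / lora_up.weight tensors on save and split back per rank on load. A rough sketch of that round trip (dimensions are made up for illustration, and the conv2d unsqueeze/squeeze handling is omitted; the slices use i : i + 1 to keep the stored per-rank shapes):

    import torch
    import torch.nn as nn

    lora_dim, in_dim, out_dim = 4, 16, 8
    lora_A = nn.ParameterList([nn.Parameter(torch.randn(1, in_dim)) for _ in range(lora_dim)])
    lora_B = nn.ParameterList([nn.Parameter(torch.randn(out_dim, 1)) for _ in range(lora_dim)])

    # save: concatenate into the shapes a plain LoRA uses
    down = torch.cat(tuple(lora_A), dim=0)  # (lora_dim, in_dim)  -> "<name>.lora_down.weight"
    up = torch.cat(tuple(lora_B), dim=1)    # (out_dim, lora_dim) -> "<name>.lora_up.weight"

    # load: slice the standard tensors back into per-rank pieces
    restored_A = [down[i : i + 1] for i in range(down.size(0))]
    restored_B = [up[:, i : i + 1] for i in range(up.size(1))]
    assert all(torch.equal(a, b) for a, b in zip(lora_A, restored_A))
    assert all(torch.equal(a, b) for a, b in zip(lora_B, restored_B))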