From 9ff32fd4c01668749058e1b7f2f2a87b3a5e6ca0 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 13 Apr 2023 21:14:20 +0900
Subject: [PATCH] fix parameters are not freezed

---
 networks/dylora.py | 104 ++++++++++++++++++++++----------------------
 1 file changed, 52 insertions(+), 52 deletions(-)

diff --git a/networks/dylora.py b/networks/dylora.py
index e588813e..c6c782fc 100644
--- a/networks/dylora.py
+++ b/networks/dylora.py
@@ -50,15 +50,17 @@ class DyLoRAModule(torch.nn.Module):
             kernel_size = org_module.kernel_size
             self.stride = org_module.stride
             self.padding = org_module.padding
-            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim, *kernel_size)))
-            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim, 1, 1)))
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim, *kernel_size)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1, 1, 1)) for _ in range(self.lora_dim)])
         else:
-            self.lora_A = nn.Parameter(org_module.weight.new_zeros((self.lora_dim, in_dim)))
-            self.lora_B = nn.Parameter(org_module.weight.new_zeros((out_dim, self.lora_dim)))
+            self.lora_A = nn.ParameterList([org_module.weight.new_zeros((1, in_dim)) for _ in range(self.lora_dim)])
+            self.lora_B = nn.ParameterList([org_module.weight.new_zeros((out_dim, 1)) for _ in range(self.lora_dim)])
 
         # same as microsoft's
-        torch.nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
-        torch.nn.init.zeros_(self.lora_B)
+        for lora in self.lora_A:
+            torch.nn.init.kaiming_uniform_(lora, a=math.sqrt(5))
+        for lora in self.lora_B:
+            torch.nn.init.zeros_(lora)
 
         self.multiplier = multiplier
         self.org_module = org_module  # remove in applying
@@ -76,38 +78,18 @@ class DyLoRAModule(torch.nn.Module):
         trainable_rank = trainable_rank - trainable_rank % self.unit  # make sure the rank is a multiple of unit
 
         # freeze some of the parameters and train the rest
+        for i in range(0, trainable_rank):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
+        for i in range(trainable_rank, trainable_rank + self.unit):
+            self.lora_A[i].requires_grad = True
+            self.lora_B[i].requires_grad = True
+        for i in range(trainable_rank + self.unit, self.lora_dim):
+            self.lora_A[i].requires_grad = False
+            self.lora_B[i].requires_grad = False
 
-        # make lora_A
-        if trainable_rank > 0:
-            lora_A_nt1 = [self.lora_A[:trainable_rank].detach()]
-        else:
-            lora_A_nt1 = []
-
-        lora_A_t = self.lora_A[trainable_rank : trainable_rank + self.unit]
-
-        if trainable_rank < self.lora_dim - self.unit:
-            lora_A_nt2 = [self.lora_A[trainable_rank + self.unit :].detach()]
-        else:
-            lora_A_nt2 = []
-
-        lora_A = torch.cat(lora_A_nt1 + [lora_A_t] + lora_A_nt2, dim=0)
-
-        # make lora_B
-        if trainable_rank > 0:
-            lora_B_nt1 = [self.lora_B[:, :trainable_rank].detach()]
-        else:
-            lora_B_nt1 = []
-
-        lora_B_t = self.lora_B[:, trainable_rank : trainable_rank + self.unit]
-
-        if trainable_rank < self.lora_dim - self.unit:
-            lora_B_nt2 = [self.lora_B[:, trainable_rank + self.unit :].detach()]
-        else:
-            lora_B_nt2 = []
-
-        lora_B = torch.cat(lora_B_nt1 + [lora_B_t] + lora_B_nt2, dim=1)
-
-        # print("lora_A", lora_A.size(), "lora_B", lora_B.size(), "x", x.size(), "result", result.size())
+        lora_A = torch.cat(tuple(self.lora_A), dim=0)
+        lora_B = torch.cat(tuple(self.lora_B), dim=1)
 
         # calculate with lora_A and lora_B
         if self.is_conv2d_3x3:
@@ -116,13 +98,13 @@ class DyLoRAModule(torch.nn.Module):
         else:
             ab = x
             if self.is_conv2d:
-                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)
+                ab = ab.reshape(ab.size(0), ab.size(1), -1).transpose(1, 2)  # (N, C, H, W) -> (N, H*W, C)
 
             ab = torch.nn.functional.linear(ab, lora_A)
             ab = torch.nn.functional.linear(ab, lora_B)
 
             if self.is_conv2d:
-                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])
+                ab = ab.transpose(1, 2).reshape(ab.size(0), -1, *x.size()[2:])  # (N, H*W, C) -> (N, C, H, W)
 
         # the last term is a scaling factor that boosts lower ranks (I think)
         result = result + ab * self.scale * math.sqrt(self.lora_dim / (trainable_rank + self.unit))
@@ -131,34 +113,52 @@ class DyLoRAModule(torch.nn.Module):
         return result
 
     def state_dict(self, destination=None, prefix="", keep_vars=False):
-        # make the state dict the same as a normal LoRA's
-        state_dict = super().state_dict(destination, prefix, keep_vars)
+        # make the state dict the same as a normal LoRA's:
+        # nn.ParameterList produces keys like `.lora_A.0`, so cat the entries (as in forward) and swap the keys
+        sd = super().state_dict(destination, prefix, keep_vars)
 
-        lora_A_weight = state_dict.pop(self.lora_name + ".lora_A")
+        lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_A_weight = lora_A_weight.unsqueeze(-1).unsqueeze(-1)
-        state_dict[self.lora_name + ".lora_down.weight"] = lora_A_weight
 
-        lora_B_weight = state_dict.pop(self.lora_name + ".lora_B")
+        lora_B_weight = torch.cat(tuple(self.lora_B), dim=1)
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_B_weight = lora_B_weight.unsqueeze(-1).unsqueeze(-1)
-        state_dict[self.lora_name + ".lora_up.weight"] = lora_B_weight
+
+        sd[self.lora_name + ".lora_down.weight"] = lora_A_weight if keep_vars else lora_A_weight.detach()
+        sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
+
+        i = 0
+        while True:
+            key_a = f"{self.lora_name}.lora_A.{i}"
+            key_b = f"{self.lora_name}.lora_B.{i}"
+            if key_a in sd:
+                sd.pop(key_a)
+                sd.pop(key_b)
+            else:
+                break
+            i += 1
 
-        return state_dict
+        return sd
 
     def load_state_dict(self, state_dict, strict=True):
         # make it possible to load the same state dict as a normal LoRA
         state_dict = state_dict.copy()
 
-        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight")
+        lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
+        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
+
+        if lora_A_weight is None or lora_B_weight is None:
+            if strict:
+                raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
+            else:
+                return
+
         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
-        state_dict[self.lora_name + ".lora_A"] = lora_A_weight
-
-        lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight")
-        if self.is_conv2d and not self.is_conv2d_3x3:
             lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
-        state_dict[self.lora_name + ".lora_B"] = lora_B_weight
+
+        state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
+        state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})
 
         super().load_state_dict(state_dict, strict=strict)
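
For context, the change works because each rank of lora_A / lora_B is now its own nn.Parameter inside an nn.ParameterList, so requires_grad can be switched per rank before the factors are concatenated for the actual computation. Below is a minimal, self-contained sketch of that idea; ToyDyLoRALinear and its simplified forward are illustrative stand-ins, not the DyLoRAModule from this patch.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDyLoRALinear(nn.Module):
    # One factor per rank, mirroring the patch: (1, in_dim) rows for A, (out_dim, 1) columns for B.
    def __init__(self, in_dim, out_dim, lora_dim=4, unit=1):
        super().__init__()
        self.lora_dim = lora_dim
        self.unit = unit
        self.lora_A = nn.ParameterList([nn.Parameter(torch.empty(1, in_dim)) for _ in range(lora_dim)])
        self.lora_B = nn.ParameterList([nn.Parameter(torch.zeros(out_dim, 1)) for _ in range(lora_dim)])
        for w in self.lora_A:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x, trainable_rank):
        # Freeze every rank outside the selected unit; autograd then skips those leaves entirely.
        for i in range(self.lora_dim):
            self.lora_A[i].requires_grad = trainable_rank <= i < trainable_rank + self.unit
            self.lora_B[i].requires_grad = trainable_rank <= i < trainable_rank + self.unit
        # Concatenate the per-rank factors back into ordinary LoRA matrices.
        lora_A = torch.cat(tuple(self.lora_A), dim=0)  # (lora_dim, in_dim)
        lora_B = torch.cat(tuple(self.lora_B), dim=1)  # (out_dim, lora_dim)
        return F.linear(F.linear(x, lora_A), lora_B)


if __name__ == "__main__":
    torch.manual_seed(0)
    m = ToyDyLoRALinear(in_dim=8, out_dim=8, lora_dim=4, unit=1)
    m(torch.randn(2, 8), trainable_rank=2).sum().backward()
    # Only the selected rank receives a gradient; the others stay frozen.
    print([p.grad is not None for p in m.lora_A])  # [False, False, True, False]

Since requires_grad is consulted when the graph is built, the flags have to be set before the torch.cat in every forward pass, which is exactly what the added loops in the patch do.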
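The state_dict / load_state_dict overrides exist because nn.ParameterList stores its entries under indexed keys (lora_A.0, lora_A.1, ...), while ordinary LoRA checkpoints use a single lora_down.weight / lora_up.weight pair per module. A rough sketch of that key remapping for the linear case is below; the helper names to_lora_format / from_lora_format are made up for illustration and are not functions in networks/dylora.py.

import torch


def to_lora_format(sd, lora_name, lora_dim):
    # Collapse '<name>.lora_A.{i}' / '<name>.lora_B.{i}' into the standard LoRA pair, as state_dict() does.
    out = dict(sd)
    rows = [out.pop(f"{lora_name}.lora_A.{i}") for i in range(lora_dim)]  # each (1, in_dim)
    cols = [out.pop(f"{lora_name}.lora_B.{i}") for i in range(lora_dim)]  # each (out_dim, 1)
    out[f"{lora_name}.lora_down.weight"] = torch.cat(rows, dim=0)  # (lora_dim, in_dim)
    out[f"{lora_name}.lora_up.weight"] = torch.cat(cols, dim=1)    # (out_dim, lora_dim)
    return out


def from_lora_format(sd, lora_name):
    # Split the standard pair back into one entry per rank, as load_state_dict() does.
    out = dict(sd)
    down = out.pop(f"{lora_name}.lora_down.weight")  # (lora_dim, in_dim)
    up = out.pop(f"{lora_name}.lora_up.weight")      # (out_dim, lora_dim)
    for i in range(down.size(0)):
        out[f"{lora_name}.lora_A.{i}"] = down[i : i + 1]   # keep the (1, in_dim) shape
        out[f"{lora_name}.lora_B.{i}"] = up[:, i : i + 1]  # keep the (out_dim, 1) shape
    return out


if __name__ == "__main__":
    lora_dim, in_dim, out_dim = 4, 8, 16
    sd = {f"net.lora_A.{i}": torch.randn(1, in_dim) for i in range(lora_dim)}
    sd.update({f"net.lora_B.{i}": torch.randn(out_dim, 1) for i in range(lora_dim)})
    lora_sd = to_lora_format(sd, "net", lora_dim)
    assert lora_sd["net.lora_down.weight"].shape == (lora_dim, in_dim)
    assert lora_sd["net.lora_up.weight"].shape == (out_dim, lora_dim)
    back = from_lora_format(lora_sd, "net")
    assert all(torch.equal(back[k], sd[k]) for k in sd)

The 1x1-conv case handled in the patch additionally unsqueezes or squeezes the trailing spatial dimensions; this sketch omits that step.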