From 92332eb96e152249c0cff09af7b33e91393426be Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 14 Apr 2023 22:13:26 +0900 Subject: [PATCH] fix load_state_dict failed in dylora --- networks/dylora.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index c6c782fc..90b509df 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -115,7 +115,7 @@ class DyLoRAModule(torch.nn.Module): def state_dict(self, destination=None, prefix="", keep_vars=False): # state dictを通常のLoRAと同じにする: # nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える - sd = super().state_dict(destination, prefix, keep_vars) + sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) lora_A_weight = torch.cat(tuple(self.lora_A), dim=0) if self.is_conv2d and not self.is_conv2d_3x3: @@ -129,7 +129,7 @@ class DyLoRAModule(torch.nn.Module): sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach() i = 0 - while True: + while True: key_a = f"{self.lora_name}.lora_A.{i}" key_b = f"{self.lora_name}.lora_B.{i}" if key_a in sd: @@ -140,10 +140,8 @@ class DyLoRAModule(torch.nn.Module): i += 1 return sd - def load_state_dict(self, state_dict, strict=True): - # 通常のLoRAと同じstate dictを読み込めるようにする - state_dict = state_dict.copy() - + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + # 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None) lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None) @@ -152,15 +150,19 @@ class DyLoRAModule(torch.nn.Module): raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found") else: return - + if self.is_conv2d and not self.is_conv2d_3x3: lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1) lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1) - state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))}) - state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))}) + state_dict.update( + {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))} + ) + state_dict.update( + {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))} + ) - super().load_state_dict(state_dict, strict=strict) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):