fix load_state_dict failed in dylora

Author: Kohya S
Date: 2023-04-14 22:13:26 +09:00
Parent: 06a9f51431
Commit: 92332eb96e


@@ -115,7 +115,7 @@ class DyLoRAModule(torch.nn.Module):
     def state_dict(self, destination=None, prefix="", keep_vars=False):
         # make the state dict identical to a normal LoRA's:
         # nn.ParameterList yields names like `.lora_A.0`, so cat the entries and swap in the merged weights, same as in forward
-        sd = super().state_dict(destination, prefix, keep_vars)
+        sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

         lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
         if self.is_conv2d and not self.is_conv2d_3x3:
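The first hunk switches the `super().state_dict()` call to keyword arguments. The likely motivation is that recent PyTorch re-declared `nn.Module.state_dict` as `state_dict(*args, destination=None, prefix="", keep_vars=False)` and deprecated positional use, while the keyword form behaves the same on old and new versions. A minimal sketch of the safe call pattern (the `Linear` layer here is just an illustration):

import torch.nn as nn

layer = nn.Linear(4, 4)
# the positional form, layer.state_dict(None, "", False), is deprecated on
# recent PyTorch; the keyword form below works the same on all versions
sd = layer.state_dict(destination=None, prefix="", keep_vars=False)
print(list(sd.keys()))  # ['weight', 'bias']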
@@ -129,7 +129,7 @@ class DyLoRAModule(torch.nn.Module):
         sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()

         i = 0
         while True:
             key_a = f"{self.lora_name}.lora_A.{i}"
             key_b = f"{self.lora_name}.lora_B.{i}"
             if key_a in sd:
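The while-loop exists because an `nn.ParameterList` registers each rank slice under an indexed key, and those per-slice keys have to be removed once the merged `lora_down`/`lora_up` tensors are in place. A standalone sketch of the naming (the `Toy` module is hypothetical):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # four rank-1 units, analogous to DyLoRA's lora_A
        self.lora_A = nn.ParameterList(nn.Parameter(torch.zeros(1, 8)) for _ in range(4))

print(list(Toy().state_dict().keys()))
# ['lora_A.0', 'lora_A.1', 'lora_A.2', 'lora_A.3']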
@@ -140,10 +140,8 @@ class DyLoRAModule(torch.nn.Module):
             i += 1
         return sd

-    def load_state_dict(self, state_dict, strict=True):
-        # allow loading the same state dict as a normal LoRA
-        state_dict = state_dict.copy()
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        # allow loading the same state dict as a normal LoRA: asked chatGPT about this approach

         lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
         lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
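This is the core of the fix. When weights are loaded through the whole network's `load_state_dict`, PyTorch never calls a child module's `load_state_dict` override; it recursively dispatches to each submodule's `_load_from_state_dict` hook instead, which is presumably why the DyLoRA key remapping never ran and loading failed. Moving the logic into the hook makes it run during a normal top-level load. The explicit `state_dict.copy()` also becomes unnecessary, since `load_state_dict` copies the dict before handing it to the hooks. A minimal sketch of the dispatch behavior:

import torch
import torch.nn as nn

class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(2))

    def load_state_dict(self, state_dict, strict=True):
        print("load_state_dict override")  # NOT called by the parent's load
        return super().load_state_dict(state_dict, strict)

    def _load_from_state_dict(self, state_dict, prefix, *args):
        print("_load_from_state_dict hook")  # called for every submodule
        super()._load_from_state_dict(state_dict, prefix, *args)

net = nn.Sequential(Child())
net.load_state_dict(net.state_dict())  # prints only "_load_from_state_dict hook"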
@@ -152,15 +150,19 @@ class DyLoRAModule(torch.nn.Module):
                 raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
             else:
                 return

         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
             lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)

-        state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
-        state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})
+        state_dict.update(
+            {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
+        )
+        state_dict.update(
+            {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
+        )

-        super().load_state_dict(state_dict, strict=strict)
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
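The other change in the last hunk is the added `unsqueeze` calls: each `lora_A` entry is registered as a 2-D slice of shape `(1, in_dim)` and each `lora_B` entry as `(out_dim, 1)`, so the rows and columns sliced out of the merged `lora_down`/`lora_up` weights must be unsqueezed back to 2-D, otherwise the shapes no longer match the registered parameters and `_load_from_state_dict` reports size mismatches. A round-trip sketch with illustrative shapes (rank 4, dimension 8):

import torch

lora_down = torch.randn(4, 8)  # merged lora_down: (rank, in_dim)
lora_up = torch.randn(8, 4)    # merged lora_up: (out_dim, rank)

# split into per-rank units shaped like the ParameterList entries
a_units = [lora_down[i].unsqueeze(0) for i in range(lora_down.size(0))]  # each (1, 8)
b_units = [lora_up[:, i].unsqueeze(1) for i in range(lora_up.size(1))]   # each (8, 1)

# concatenating the units reproduces the merged weights, mirroring state_dict()
assert torch.equal(torch.cat(a_units, dim=0), lora_down)
assert torch.equal(torch.cat(b_units, dim=1), lora_up)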