mirror of https://github.com/kohya-ss/sd-scripts.git
fix load_state_dict failed in dylora
@@ -115,7 +115,7 @@ class DyLoRAModule(torch.nn.Module):
     def state_dict(self, destination=None, prefix="", keep_vars=False):
         # make the state dict the same format as a regular LoRA's
         # nn.ParameterList entries get names like `.lora_A.0`, so cat them and swap them in, the same way forward does
-        sd = super().state_dict(destination, prefix, keep_vars)
+        sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

         lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
         if self.is_conv2d and not self.is_conv2d_3x3:
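The two comments above describe the save-side half of the fix: DyLoRA keeps one parameter per rank unit in an nn.ParameterList, and state_dict concatenates them back into the single lora_down / lora_up tensors a regular LoRA checkpoint uses. A minimal sketch of that flattening, with illustrative shapes (rank 4, in_dim 8, out_dim 16 are assumptions, not values from the repo):

import torch
import torch.nn as nn

# hypothetical per-unit parameters: one row of lora_down and one column of lora_up per rank unit
lora_A = nn.ParameterList([nn.Parameter(torch.randn(1, 8)) for _ in range(4)])    # each (1, in_dim)
lora_B = nn.ParameterList([nn.Parameter(torch.randn(16, 1)) for _ in range(4)])   # each (out_dim, 1)

# concatenate the units into the usual LoRA shapes, as the hunk above does for lora_A
lora_down_weight = torch.cat(tuple(lora_A), dim=0)  # (rank, in_dim)  -> ".lora_down.weight"
lora_up_weight = torch.cat(tuple(lora_B), dim=1)    # (out_dim, rank) -> ".lora_up.weight"

print(lora_down_weight.shape, lora_up_weight.shape)  # torch.Size([4, 8]) torch.Size([16, 4])

Passing destination/prefix/keep_vars as keywords on the new right-hand side also matches recent torch.nn.Module.state_dict signatures, which treat positional use of those arguments as deprecated.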
@@ -129,7 +129,7 @@ class DyLoRAModule(torch.nn.Module):
         sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()

         i = 0
         while True:
             key_a = f"{self.lora_name}.lora_A.{i}"
             key_b = f"{self.lora_name}.lora_B.{i}"
             if key_a in sd:
@@ -140,10 +140,8 @@ class DyLoRAModule(torch.nn.Module):
             i += 1
         return sd

-    def load_state_dict(self, state_dict, strict=True):
-        # allow loading the same state dict as a regular LoRA
-        state_dict = state_dict.copy()
-
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+        # allow loading the same state dict as a regular LoRA; this approach came from asking chatGPT
         lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
         lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)

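The rename above is the core of the fix: nn.Module.load_state_dict on the parent network never calls an overridden load_state_dict on its submodules; it recurses via each module's _load_from_state_dict hook, so the key remapping in the old override simply never ran. A small self-contained sketch of the same pattern (SplitLinear, rows, and the fused "weight" key are made-up names for illustration):

import torch
import torch.nn as nn

class SplitLinear(nn.Module):
    # toy module: stores its weight as per-row parameters, but checkpoints hold one fused "weight" tensor
    def __init__(self, in_dim=4, out_dim=3):
        super().__init__()
        self.rows = nn.ParameterList([nn.Parameter(torch.zeros(1, in_dim)) for _ in range(out_dim)])

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # called by nn.Module.load_state_dict for every submodule, so the remap
        # runs even when this module sits inside a larger network
        fused = state_dict.pop(prefix + "weight", None)
        if fused is not None:
            for i in range(fused.size(0)):
                state_dict[f"{prefix}rows.{i}"] = fused[i].unsqueeze(0)  # back to per-row shape (1, in_dim)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

net = nn.ModuleDict({"layer": SplitLinear()})
ckpt = {"layer.weight": torch.arange(12.0).reshape(3, 4)}  # fused key, like a regular LoRA checkpoint
net.load_state_dict(ckpt)  # an overridden layer.load_state_dict would never be reached from here
print(net["layer"].rows[1])  # holds the second row of the fused tensor: 4., 5., 6., 7.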
@@ -152,15 +150,19 @@ class DyLoRAModule(torch.nn.Module):
                 raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
             else:
                 return

         if self.is_conv2d and not self.is_conv2d_3x3:
             lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
             lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)

-        state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
-        state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})
+        state_dict.update(
+            {f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
+        )
+        state_dict.update(
+            {f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
+        )

-        super().load_state_dict(state_dict, strict=strict)
+        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


 def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
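The .unsqueeze(0) / .unsqueeze(1) added in the comprehensions keeps every restored entry two-dimensional, (1, in_dim) rows for lora_A and (out_dim, 1) columns for lora_B, so the shapes match the parameters created at init and the split is the exact inverse of the cat done in state_dict. A quick round-trip check under those assumed shapes:

import torch

rank, in_dim, out_dim = 4, 8, 16       # illustrative sizes only
lora_down = torch.randn(rank, in_dim)  # fused ".lora_down.weight" as written by state_dict
lora_up = torch.randn(out_dim, rank)   # fused ".lora_up.weight"

# split back into per-unit tensors, keeping them 2D as the new code above does
units_A = [lora_down[i].unsqueeze(0) for i in range(lora_down.size(0))]  # each (1, in_dim)
units_B = [lora_up[:, i].unsqueeze(1) for i in range(lora_up.size(1))]   # each (out_dim, 1)

# concatenating the units reproduces the fused weights exactly
assert torch.equal(torch.cat(units_A, dim=0), lora_down)
assert torch.equal(torch.cat(units_B, dim=1), lora_up)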
|