mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix load_state_dict failed in dylora
This commit is contained in:
@@ -115,7 +115,7 @@ class DyLoRAModule(torch.nn.Module):
|
||||
def state_dict(self, destination=None, prefix="", keep_vars=False):
|
||||
# state dictを通常のLoRAと同じにする:
|
||||
# nn.ParameterListは `.lora_A.0` みたいな名前になるので、forwardと同様にcatして入れ替える
|
||||
sd = super().state_dict(destination, prefix, keep_vars)
|
||||
sd = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
lora_A_weight = torch.cat(tuple(self.lora_A), dim=0)
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
@@ -129,7 +129,7 @@ class DyLoRAModule(torch.nn.Module):
|
||||
sd[self.lora_name + ".lora_up.weight"] = lora_B_weight if keep_vars else lora_B_weight.detach()
|
||||
|
||||
i = 0
|
||||
while True:
|
||||
while True:
|
||||
key_a = f"{self.lora_name}.lora_A.{i}"
|
||||
key_b = f"{self.lora_name}.lora_B.{i}"
|
||||
if key_a in sd:
|
||||
@@ -140,10 +140,8 @@ class DyLoRAModule(torch.nn.Module):
|
||||
i += 1
|
||||
return sd
|
||||
|
||||
def load_state_dict(self, state_dict, strict=True):
|
||||
# 通常のLoRAと同じstate dictを読み込めるようにする
|
||||
state_dict = state_dict.copy()
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
# 通常のLoRAと同じstate dictを読み込めるようにする:この方法はchatGPTに聞いた
|
||||
lora_A_weight = state_dict.pop(self.lora_name + ".lora_down.weight", None)
|
||||
lora_B_weight = state_dict.pop(self.lora_name + ".lora_up.weight", None)
|
||||
|
||||
@@ -152,15 +150,19 @@ class DyLoRAModule(torch.nn.Module):
|
||||
raise KeyError(f"{self.lora_name}.lora_down/up.weight is not found")
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
if self.is_conv2d and not self.is_conv2d_3x3:
|
||||
lora_A_weight = lora_A_weight.squeeze(-1).squeeze(-1)
|
||||
lora_B_weight = lora_B_weight.squeeze(-1).squeeze(-1)
|
||||
|
||||
state_dict.update({f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i]) for i in range(lora_A_weight.size(0))})
|
||||
state_dict.update({f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i]) for i in range(lora_B_weight.size(1))})
|
||||
state_dict.update(
|
||||
{f"{self.lora_name}.lora_A.{i}": nn.Parameter(lora_A_weight[i].unsqueeze(0)) for i in range(lora_A_weight.size(0))}
|
||||
)
|
||||
state_dict.update(
|
||||
{f"{self.lora_name}.lora_B.{i}": nn.Parameter(lora_B_weight[:, i].unsqueeze(1)) for i in range(lora_B_weight.size(1))}
|
||||
)
|
||||
|
||||
super().load_state_dict(state_dict, strict=strict)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
|
||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user