Detach and clone original LoRA weights before training

This commit is contained in:
rockerBOO
2025-03-24 19:01:11 -04:00
parent 3356314002
commit 0bad5ae9f1

View File

@@ -85,15 +85,15 @@ class LoRAModule(torch.nn.Module):
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
if initialize == "urae":
initialize_urae(org_module, self.lora_down, self.lora_up, self.lora_dim)
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data
self._org_lora_down = self.lora_down.weight.data
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif initialize == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data
self._org_lora_down = self.lora_down.weight.data
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
else:
initialize_lora(self.lora_down, self.lora_up)
else:
@@ -109,13 +109,13 @@ class LoRAModule(torch.nn.Module):
if initialize == "urae":
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data
self._org_lora_down = lora_down.weight.data
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
elif initialize == "pissa":
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data
self._org_lora_down = lora_down.weight.data
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
else:
initialize_lora(lora_down, lora_up)
@@ -1101,19 +1101,19 @@ class LoRANetwork(torch.nn.Module):
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
# Calculate ΔW = A'B' - AB
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
# We need to create new low-rank matrices that represent this delta
# One approach is to do SVD on delta_w
U, S, V = torch.linalg.svd(delta_w, full_matrices=False)
# Take the top 2*r singular values (as suggested in the paper)
rank = rank * 2
rank = min(rank, len(S)) # Make sure we don't exceed available singular values
# Create new LoRA matrices
new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
# These matrices can now be used as standard LoRA weights
return new_up, new_down