Mirror of https://github.com/kohya-ss/sd-scripts.git
Detach and clone original LoRA weights before training
@@ -85,15 +85,15 @@ class LoRAModule(torch.nn.Module):
             self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
 
             if initialize == "urae":
-                initialize_urae(org_module, self.lora_down, self.lora_up, self.lora_dim)
+                initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
                 # Need to store the original weights so we can get a plain LoRA out
-                self._org_lora_up = self.lora_up.weight.data
-                self._org_lora_down = self.lora_down.weight.data
+                self._org_lora_up = self.lora_up.weight.data.detach().clone()
+                self._org_lora_down = self.lora_down.weight.data.detach().clone()
             elif initialize == "pissa":
                 initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
                 # Need to store the original weights so we can get a plain LoRA out
-                self._org_lora_up = self.lora_up.weight.data
-                self._org_lora_down = self.lora_down.weight.data
+                self._org_lora_up = self.lora_up.weight.data.detach().clone()
+                self._org_lora_down = self.lora_down.weight.data.detach().clone()
             else:
                 initialize_lora(self.lora_down, self.lora_up)
         else:
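The reason for this change (the same fix is applied again in the next hunk) is that `weight.data` is just another view of the live parameter storage, so a "snapshot" stored that way keeps changing as the optimizer updates the weights; `.detach().clone()` makes an independent copy that still holds the initialization values after training. A minimal standalone sketch of the aliasing behaviour, using plain PyTorch and nothing from sd-scripts:

import torch

# Illustration only, not part of the repository.
layer = torch.nn.Linear(4, 4, bias=False)

alias = layer.weight.data                        # shares storage with the live parameter
snapshot = layer.weight.data.detach().clone()    # independent copy of the current values

# Simulate an in-place optimizer step.
with torch.no_grad():
    layer.weight -= 0.1 * torch.ones_like(layer.weight)

print(torch.equal(alias, layer.weight.data))     # True: the "original" moved with the update
print(torch.equal(snapshot, layer.weight.data))  # False: the clone kept the pre-update values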
@@ -109,13 +109,13 @@ class LoRAModule(torch.nn.Module):
             if initialize == "urae":
                 initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
                 # Need to store the original weights so we can get a plain LoRA out
-                self._org_lora_up = lora_up.weight.data
-                self._org_lora_down = lora_down.weight.data
+                self._org_lora_up = lora_up.weight.data.detach().clone()
+                self._org_lora_down = lora_down.weight.data.detach().clone()
             elif initialize == "pissa":
                 initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
                 # Need to store the original weights so we can get a plain LoRA out
-                self._org_lora_up = lora_up.weight.data
-                self._org_lora_down = lora_down.weight.data
+                self._org_lora_up = lora_up.weight.data.detach().clone()
+                self._org_lora_down = lora_down.weight.data.detach().clone()
             else:
                 initialize_lora(lora_down, lora_up)
 
@@ -1101,19 +1101,19 @@ class LoRANetwork(torch.nn.Module):
 def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
     # Calculate ΔW = A'B' - AB
     delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
 
     # We need to create new low-rank matrices that represent this delta
     # One approach is to do SVD on delta_w
     U, S, V = torch.linalg.svd(delta_w, full_matrices=False)
 
     # Take the top 2*r singular values (as suggested in the paper)
     rank = rank * 2
     rank = min(rank, len(S))  # Make sure we don't exceed available singular values
 
     # Create new LoRA matrices
     new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
     new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
 
     # These matrices can now be used as standard LoRA weights
     return new_up, new_down
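Since both the trained and the original PiSSA factors have rank at most `rank`, their difference ΔW has rank at most 2*rank, so the truncated SVD above should reconstruct it essentially exactly. A usage sketch with made-up shapes (the dimensions and the random factors are illustrative, not taken from the repository), assuming convert_pissa_to_standard_lora from the hunk above is in scope:

import torch

# Illustrative shapes only; real factors would come from a trained LoRAModule.
out_dim, in_dim, r = 64, 32, 4
orig_up, orig_down = torch.randn(out_dim, r), torch.randn(r, in_dim)
trained_up = orig_up + 0.1 * torch.randn(out_dim, r)
trained_down = orig_down + 0.1 * torch.randn(r, in_dim)

new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, r)

# delta_w has rank <= 2*r, so the top-2r factors should reproduce it up to float error.
delta_w = trained_up @ trained_down - orig_up @ orig_down
print(new_up.shape, new_down.shape)                            # expected: torch.Size([64, 8]) torch.Size([8, 32])
print(torch.allclose(new_up @ new_down, delta_w, atol=1e-4))   # expected: True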