From 0bad5ae9f16e235660e43101e842a380bc62c814 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Mon, 24 Mar 2025 19:01:11 -0400
Subject: [PATCH] Detach and clone original LoRA weights before training

---
 networks/lora_flux.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index f5c5f5eb..b9b11423 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -85,15 +85,15 @@ class LoRAModule(torch.nn.Module):
             self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
 
             if initialize == "urae":
-                initialize_urae(org_module, self.lora_down, self.lora_up, self.lora_dim)
+                initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
                 # Need to store the original weights so we can get a plain LoRA out
-                self._org_lora_up = self.lora_up.weight.data
-                self._org_lora_down = self.lora_down.weight.data
+                self._org_lora_up = self.lora_up.weight.data.detach().clone()
+                self._org_lora_down = self.lora_down.weight.data.detach().clone()
             elif initialize == "pissa":
                 initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
                 # Need to store the original weights so we can get a plain LoRA out
-                self._org_lora_up = self.lora_up.weight.data
-                self._org_lora_down = self.lora_down.weight.data
+                self._org_lora_up = self.lora_up.weight.data.detach().clone()
+                self._org_lora_down = self.lora_down.weight.data.detach().clone()
             else:
                 initialize_lora(self.lora_down, self.lora_up)
         else:
@@ -109,13 +109,13 @@ class LoRAModule(torch.nn.Module):
             if initialize == "urae":
                 initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
                 # Need to store the original weights so we can get a plain LoRA out
-                self._org_lora_up = lora_up.weight.data
-                self._org_lora_down = lora_down.weight.data
+                self._org_lora_up = lora_up.weight.data.detach().clone()
+                self._org_lora_down = lora_down.weight.data.detach().clone()
             elif initialize == "pissa":
                 initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
                 # Need to store the original weights so we can get a plain LoRA out
-                self._org_lora_up = lora_up.weight.data
-                self._org_lora_down = lora_down.weight.data
+                self._org_lora_up = lora_up.weight.data.detach().clone()
+                self._org_lora_down = lora_down.weight.data.detach().clone()
             else:
                 initialize_lora(lora_down, lora_up)
 
@@ -1101,19 +1101,19 @@ class LoRANetwork(torch.nn.Module):
 
 def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
     # Calculate ΔW = A'B' - AB
     delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
-    
+
     # We need to create new low-rank matrices that represent this delta
     # One approach is to do SVD on delta_w
     U, S, V = torch.linalg.svd(delta_w, full_matrices=False)
-    
+
     # Take the top 2*r singular values (as suggested in the paper)
     rank = rank * 2
     rank = min(rank, len(S))  # Make sure we don't exceed available singular values
-    
+
     # Create new LoRA matrices
     new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
     new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
-    
+
     # These matrices can now be used as standard LoRA weights
     return new_up, new_down
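
Why `.detach().clone()` matters here: `Tensor.data` is a view of the parameter's storage, so `self._org_lora_up = self.lora_up.weight.data` only stores an alias, and the saved "original" weights silently drift as the optimizer updates `lora_up` during training. Below is a minimal, self-contained sketch of the difference; it is not part of the patch, and the shapes and variable names are hypothetical, not taken from `LoRAModule`.

```python
import torch

# Stand-in for a freshly initialized LoRA "up" projection (hypothetical shapes).
lora_up = torch.nn.Linear(4, 8, bias=False)

# Without clone: this is only an alias of the same storage.
aliased_org = lora_up.weight.data
# With detach().clone(): an independent snapshot, as in the patch.
snapshot_org = lora_up.weight.data.detach().clone()

# Simulate a training step that updates the weights in place.
with torch.no_grad():
    lora_up.weight += 0.1

print(torch.equal(aliased_org, lora_up.weight.data))   # True  - the alias drifted along with training
print(torch.equal(snapshot_org, lora_up.weight.data))  # False - the snapshot kept the initial values
```

For URAE and PiSSA initialization the down/up weights start non-zero, so without the clone the stored `_org_lora_up` / `_org_lora_down` would end up equal to the trained weights and the extracted delta would collapse to zero.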
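
The final hunk only strips trailing whitespace, but `convert_pissa_to_standard_lora` is what consumes those snapshots: it takes the SVD of the weight delta and re-factorizes it at twice the original rank. A hedged usage sketch follows; the standalone tensors and shapes are illustrative stand-ins (in the repo, `orig_up` / `orig_down` come from a module's stored `_org_lora_up` / `_org_lora_down`), and it repeats the function's steps rather than importing it.

```python
import torch

rank, in_dim, out_dim = 4, 16, 32

# Stand-ins for the stored original factors and the factors after training.
orig_down, orig_up = torch.randn(rank, in_dim), torch.randn(out_dim, rank)
trained_down = orig_down + 0.01 * torch.randn(rank, in_dim)
trained_up = orig_up + 0.01 * torch.randn(out_dim, rank)

# Same steps as convert_pissa_to_standard_lora: SVD of the delta, truncated to 2 * rank.
delta_w = trained_up @ trained_down - orig_up @ orig_down
U, S, Vh = torch.linalg.svd(delta_w, full_matrices=False)
r = min(rank * 2, len(S))
new_up = U[:, :r] @ torch.diag(torch.sqrt(S[:r]))
new_down = torch.diag(torch.sqrt(S[:r])) @ Vh[:r, :]

# new_up @ new_down reproduces the learned update, so the pair can be saved as a plain LoRA.
print(torch.allclose(new_up @ new_down, delta_w, atol=1e-5))  # True
```

The doubled rank is enough here because the delta of two rank-r products has rank at most 2r, so the truncated SVD loses nothing.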