Add lowrank SVD for PiSSA. Implement URAE conversion

This commit is contained in:
rockerBOO
2025-05-07 23:26:26 -04:00
parent d0eb3b5c79
commit ef8371243b
2 changed files with 250 additions and 51 deletions

View File

@@ -2,6 +2,90 @@ import torch
import math
import warnings
from typing import Optional
from library.incremental_pca import IncrementalPCA
from dataclasses import dataclass
@dataclass
class InitializeParams:
    """Parameters for initialization methods (PiSSA, URAE)"""
    # Use IncrementalPCA instead of a direct SVD (batched fit; suits large matrices)
    use_ipca: bool = False
    # Use the low-rank SVD approximation (torch.svd_lowrank) instead of a full SVD
    use_lowrank: bool = True
    # Oversampling dimension q for the low-rank SVD; None lets the caller pick a default
    lowrank_q: Optional[int] = None
    # Number of subspace iterations for the low-rank SVD
    lowrank_niter: int = 4
    # RNG seed for the randomized low-rank SVD; None leaves the global RNG untouched
    lowrank_seed: Optional[int] = None
def initialize_parse_opts(key: str) -> InitializeParams:
    """
    Parse initialization parameters from a string key.
    Format examples:
    - "pissa" -> Default PiSSA with lowrank=True, niter=4
    - "pissa_niter_4" -> PiSSA with niter=4
    - "pissa_lowrank_false" -> PiSSA without lowrank
    - "pissa_ipca_true" -> PiSSA with IPCA
    - "pissa_q_16" -> PiSSA with lowrank_q=16
    - "pissa_seed_42" -> PiSSA with seed=42
    - "urae_..." -> Same options but for URAE
    Args:
        key: String key to parse
    Returns:
        InitializeParams object with parsed parameters
    """
    tokens = key.lower().split("_")

    # The method name is always the leading token.
    if tokens[0] not in ("pissa", "urae"):
        raise ValueError(f"Unknown initialization method: {tokens[0]}")

    opts = InitializeParams()
    n = len(tokens)
    idx = 1
    while idx < n:
        name = tokens[idx]
        # One-token lookahead: the candidate value, if any.
        value = tokens[idx + 1] if idx + 1 < n else None

        if name in ("ipca", "lowrank"):
            # Boolean flags: an explicit true/false consumes two tokens,
            # a bare flag consumes one and means True.
            if value in ("true", "false"):
                flag = value == "true"
                idx += 2
            else:
                flag = True
                idx += 1
            if name == "ipca":
                opts.use_ipca = flag
            else:
                opts.use_lowrank = flag
        elif name in ("niter", "q", "seed"):
            # Integer options: only consume the value token when it is a
            # plain digit string; otherwise skip the option name alone.
            if value is not None and value.isdigit():
                number = int(value)
                if name == "niter":
                    opts.lowrank_niter = number
                elif name == "q":
                    opts.lowrank_q = number
                else:
                    opts.lowrank_seed = number
                idx += 2
            else:
                idx += 1
        else:
            # Unknown token: skip it.
            idx += 1

    return opts
def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
@@ -18,49 +102,79 @@ def initialize_urae(
rank: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
use_ipca: bool = False,
use_lowrank: bool = True,
lowrank_q: Optional[int] = None,
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,
):
org_module_device = org_module.weight.device
org_module_weight_dtype = org_module.weight.data.dtype
org_module_requires_grad = org_module.weight.requires_grad
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
assert isinstance(device, torch.device), f"Invalid device type: {device}"
weight = org_module.weight.data.to(device, dtype=torch.float32)
with torch.autocast(device.type):
# SVD decomposition
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
if use_ipca:
# For URAE we need all components to get the "residual" ones
ipca = IncrementalPCA(
n_components=None, # Get all components
batch_size=1024,
lowrank=use_lowrank,
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), # Use full rank for accurate residuals
lowrank_niter=lowrank_niter,
lowrank_seed=lowrank_seed,
)
ipca.fit(weight)
# For URAE, use the LAST/SMALLEST singular values and vectors (residual components)
# For URAE, use the LAST/SMALLEST singular values
total_rank = min(weight.shape[0], weight.shape[1])
V_full = ipca.components_.T # [out_features, total_rank]
S_full = ipca.singular_values_ # [total_rank]
# Get the smallest singular values and vectors
Vr = V_full[:, -rank:] # Last rank left singular vectors
Sr = S_full[-rank:] # Last rank singular values
Sr /= rank
# To get Uhr (last rank right singular vectors), transform basis vectors
identity = torch.eye(weight.shape[1], device=weight.device)
Uhr_full = ipca.transform(identity).T # [total_rank, in_features]
Uhr = Uhr_full[-rank:] # Last rank right singular vectors
else:
# Standard SVD approach
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, -rank:]
Sr = S[-rank:]
Sr /= rank
Uhr = Uh[-rank:, :]
# Create down and up matrices
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
# Create down and up matrices
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
# Get expected shapes
expected_down_shape = lora_down.weight.shape
expected_up_shape = lora_up.weight.shape
# Get expected shapes
expected_down_shape = lora_down.weight.shape
expected_up_shape = lora_up.weight.shape
# Verify shapes match expected
if down.shape != expected_down_shape:
warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
# Verify shapes match expected
if down.shape != expected_down_shape:
warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
if up.shape != expected_up_shape:
warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
if up.shape != expected_up_shape:
warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
# Assign to LoRA weights
lora_up.weight.data = up
lora_down.weight.data = down
# Assign to LoRA weights
lora_up.weight.data = up
lora_down.weight.data = down
# Optionally, subtract from original weight
weight = weight - scale * (up @ down)
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
org_module.weight.requires_grad = org_module_requires_grad
# Optionally, subtract from original weight
weight = weight - scale * (up @ down)
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
org_module.weight.requires_grad = org_module_requires_grad
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
@@ -72,24 +186,68 @@ def initialize_pissa(
rank: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
use_ipca: bool = False,
use_lowrank: bool = True,
lowrank_q: Optional[int] = None,
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,
):
org_module_device = org_module.weight.device
org_module_weight_dtype = org_module.weight.data.dtype
org_module_requires_grad = org_module.weight.requires_grad
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
assert isinstance(device, torch.device), f"Invalid device type: {device}"
weight = org_module.weight.data.clone().to(device, dtype=torch.float32)
with torch.no_grad():
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, :rank]
Sr = S[:rank]
Sr /= rank
Uhr = Uh[:rank]
if use_ipca:
# Use Incremental PCA for large matrices
ipca = IncrementalPCA(
n_components=rank,
batch_size=1024,
lowrank=use_lowrank,
lowrank_q=lowrank_q if lowrank_q is not None else 2 * rank,
lowrank_niter=lowrank_niter,
lowrank_seed=lowrank_seed,
)
ipca.fit(weight)
# Extract principal components and singular values
Vr = ipca.components_.T # [out_features, rank]
Sr = ipca.singular_values_ # [rank]
Sr /= rank
# We need to get Uhr from transforming an identity matrix
identity = torch.eye(weight.shape[1], device=weight.device)
Uhr = ipca.transform(identity).T # [rank, in_features]
elif use_lowrank:
# Use low-rank SVD approximation which is faster
seed_enabled = lowrank_seed is not None
q_value = lowrank_q if lowrank_q is not None else 2 * rank
with torch.random.fork_rng(enabled=seed_enabled):
if seed_enabled:
torch.manual_seed(lowrank_seed)
U, S, V = torch.svd_lowrank(weight, q=q_value, niter=lowrank_niter)
Vr = U[:, :rank] # First rank left singular vectors
Sr = S[:rank] # First rank singular values
Sr /= rank
Uhr = V[:rank] # First rank right singular vectors
else:
# Standard SVD approach
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, :rank]
Sr = S[:rank]
Sr /= rank
Uhr = Uh[:rank]
# Create down and up matrices
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))

View File

@@ -7,6 +7,7 @@
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
from dataclasses import asdict
import math
import os
from typing import Dict, List, Optional, Type, Union
@@ -109,32 +110,36 @@ class LoRAModule(torch.nn.Module):
device: device to run initialization computation on
"""
if self.split_dims is None:
if initialize == "urae":
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif initialize == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
if initialize is not None:
params = initialize_parse_opts(initialize)
if initialize[:4] == "urae":
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif initialize[:5] == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
else:
initialize_lora(self.lora_down, self.lora_up)
else:
assert isinstance(self.lora_down, torch.nn.ModuleList)
assert isinstance(self.lora_up, torch.nn.ModuleList)
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
if initialize == "urae":
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=device)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
elif initialize == "pissa":
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=device)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
if initialize is not None:
params = initialize_parse_opts(initialize)
if initialize[:4] == "urae":
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
elif initialize[:5] == "pissa":
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
else:
initialize_lora(lora_down, lora_up)
@@ -1305,7 +1310,8 @@ class LoRANetwork(torch.nn.Module):
state_dict = self.state_dict()
if self.initialize in ['pissa']:
# Need to decompose the parameters into a LoRA format
if self.initialize is not None and (self.initialize[:5] == "pissa" or self.initialize[:4] == "urae"):
loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
# Calculate ΔW = A'B' - AB
@@ -1325,14 +1331,49 @@ class LoRANetwork(torch.nn.Module):
# These matrices can now be used as standard LoRA weights
return new_up, new_down
def convert_urae_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
    """Convert trained URAE weights into standard LoRA up/down factors.

    Computes the delta the training contributed, ΔW = A'B' - AB, then
    re-factorizes it with an SVD. URAE keeps the *smallest* singular
    values (the residual subspace) — the opposite of PiSSA, which keeps
    the largest ones.

    Args:
        trained_up: trained up matrix A', shape [out_features, r]
        trained_down: trained down matrix B', shape [r, in_features]
        orig_up: up matrix A captured at initialization
        orig_down: down matrix B captured at initialization
        rank: LoRA rank; up to 2*rank singular components are retained

    Returns:
        (new_up, new_down) factor pair usable as standard LoRA weights,
        on the same device the inputs live on.
    """
    # Calculate ΔW = A'B' - AB
    delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
    # We need to create new low-rank matrices that represent this delta.
    # Fix: run the SVD on whatever device the tensors already occupy
    # instead of hard-coding "cuda", which crashed on CPU-only machines
    # and silently moved the result off the caller's device.
    # torch.linalg.svd returns (U, S, Vh) with Vh already transposed.
    U, S, Vh = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False)
    # For URAE, we want to focus on the smallest singular values:
    # take the bottom rank*2 components (opposite of PiSSA, which takes the top).
    total_rank = len(S)
    rank_to_use = min(rank * 2, total_rank)
    if rank_to_use < total_rank:
        # Use the smallest singular values and vectors (residual subspace)
        selected_U = U[:, -rank_to_use:]
        selected_S = S[-rank_to_use:]
        selected_Vh = Vh[-rank_to_use:, :]
    else:
        # If we'd use all values, just use the standard approach but with a note
        print("Warning: Requested rank is too large for URAE specialty, using all singular values")
        selected_U = U
        selected_S = S
        selected_Vh = Vh
    # Split sqrt(S) evenly between the two factors: up = U·√S, down = √S·Vh
    new_up = selected_U @ torch.diag(torch.sqrt(selected_S))
    new_down = torch.diag(torch.sqrt(selected_S)) @ selected_Vh
    # These matrices can now be used as standard LoRA weights
    return new_up, new_down
with torch.no_grad():
progress = tqdm(total=len(loras), desc="Convert PiSSA")
progress = tqdm(total=len(loras), desc="Converting")
for lora in loras:
lora_up_key = f"{lora.lora_name}.lora_up.weight"
lora_down_key = f"{lora.lora_name}.lora_down.weight"
lora_up = state_dict[lora_up_key]
lora_down = state_dict[lora_down_key]
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
if self.initialize[:4] == "urae":
up, down = convert_urae_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
elif self.initialize[:5] == "pissa":
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
# TODO: Capture option if we should offload
# offload to CPU
state_dict[lora_up_key] = up.detach()