Update URAE initialization, conversion. Add tests

rockerBOO
2025-05-08 19:05:08 -04:00
parent e0f1ae0f2c
commit 89b6f8bcea
3 changed files with 427 additions and 102 deletions


@@ -1,6 +1,7 @@
 import torch
 import math
 import warnings
+from torch import Tensor
 from typing import Optional
 from library.incremental_pca import IncrementalPCA
 from dataclasses import dataclass
@@ -108,73 +109,73 @@ def initialize_urae(
     lowrank_niter: int = 4,
     lowrank_seed: Optional[int] = None,
 ):
-    org_module_device = org_module.weight.device
-    org_module_weight_dtype = org_module.weight.data.dtype
-    org_module_requires_grad = org_module.weight.requires_grad
-    dtype = dtype if dtype is not None else lora_down.weight.data.dtype
+    # Store original device, dtype, and requires_grad status
+    orig_device = org_module.weight.device
+    orig_dtype = org_module.weight.data.dtype
+    orig_requires_grad = org_module.weight.requires_grad
+
+    # Determine device and dtype to work with
     device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
     assert isinstance(device, torch.device), f"Invalid device type: {device}"
+    dtype = dtype if dtype is not None else lora_down.weight.data.dtype
+
+    # Move original weight to chosen device and use float32 for numerical stability
     weight = org_module.weight.data.to(device, dtype=torch.float32)
+
+    # Perform SVD decomposition (either directly or with IPCA for memory efficiency)
     if use_ipca:
-        # For URAE we need all components to get the "residual" ones
         ipca = IncrementalPCA(
-            n_components=None,  # Get all components
+            n_components=None,
             batch_size=1024,
             lowrank=use_lowrank,
-            lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),  # Use full rank for accurate residuals
+            lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
             lowrank_niter=lowrank_niter,
             lowrank_seed=lowrank_seed,
         )
         ipca.fit(weight)
-        # For URAE, use the LAST/SMALLEST singular values
-        total_rank = min(weight.shape[0], weight.shape[1])
-        V_full = ipca.components_.T  # [out_features, total_rank]
-        S_full = ipca.singular_values_  # [total_rank]
-        # Get the smallest singular values and vectors
-        Vr = V_full[:, -rank:]  # Last rank left singular vectors
-        Sr = S_full[-rank:]  # Last rank singular values
-        Sr /= rank
-        # To get Uhr (last rank right singular vectors), transform basis vectors
+
+        # Extract singular values and vectors, focusing on the minor components (smallest singular values)
+        S_full = ipca.singular_values_
+        V_full = ipca.components_.T  # Shape: [out_features, total_rank]
+
+        # Get identity matrix to transform for right singular vectors
         identity = torch.eye(weight.shape[1], device=weight.device)
-        Uhr_full = ipca.transform(identity).T  # [total_rank, in_features]
-        Uhr = Uhr_full[-rank:]  # Last rank right singular vectors
+        Uhr_full = ipca.transform(identity).T  # Shape: [total_rank, in_features]
+
+        # Extract the last 'rank' components (the minor/smallest ones)
+        Sr = S_full[-rank:]
+        Vr = V_full[:, -rank:]
+        Uhr = Uhr_full[-rank:]
+
+        # Scale singular values
+        Sr = Sr / rank
     else:
-        # Standard SVD approach
-        V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
-        Vr = V[:, -rank:]
+        # Direct SVD approach
+        U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
+
+        # Extract the minor components (smallest singular values)
         Sr = S[-rank:]
-        Sr /= rank
-        Uhr = Uh[-rank:, :]
+        Vr = U[:, -rank:]
+        Uhr = Vh[-rank:]
+
+        # Scale singular values
+        Sr = Sr / rank
-    # Create down and up matrices
-    down = torch.diag(torch.sqrt(Sr)) @ Uhr
-    up = Vr @ torch.diag(torch.sqrt(Sr))
-    # Get expected shapes
-    expected_down_shape = lora_down.weight.shape
-    expected_up_shape = lora_up.weight.shape
-    # Verify shapes match expected
-    if down.shape != expected_down_shape:
-        warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
-    if up.shape != expected_up_shape:
-        warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
-    # Assign to LoRA weights
-    lora_up.weight.data = up
-    lora_down.weight.data = down
-    # Optionally, subtract from original weight
-    weight = weight - scale * (up @ down)
-    org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
-    org_module.weight.requires_grad = org_module_requires_grad
+    # Create the low-rank adapter matrices by splitting the minor components
+    # Down matrix: scaled right singular vectors with singular values
+    down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr
+
+    # Up matrix: scaled left singular vectors with singular values
+    up_matrix = Vr @ torch.diag(torch.sqrt(Sr))
+
+    # Assign to LoRA modules
+    lora_down.weight.data = down_matrix.to(device=device, dtype=dtype)
+    lora_up.weight.data = up_matrix.to(device=device, dtype=dtype)
+
+    # Update the original weight by removing the minor components
+    # This is equivalent to keeping only the major components
+    modified_weight = weight - scale * (up_matrix @ down_matrix)
+    org_module.weight.data = modified_weight.to(device=orig_device, dtype=orig_dtype)
+    org_module.weight.requires_grad = orig_requires_grad


 # PiSSA: Principal Singular Values and Singular Vectors Adaptation
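Note on initialize_urae above: unlike PiSSA, URAE builds the adapter from the smallest singular triplets of the weight. A minimal standalone sketch of the direct-SVD branch (the shapes and rank value are illustrative, only torch is assumed) shows that the adapter product reproduces exactly the minor-component part of the weight, scaled by 1/rank:

import torch

torch.manual_seed(0)
rank = 4
W = torch.randn(64, 32)

# Same decomposition as the direct-SVD branch of initialize_urae
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
Sr = S[-rank:] / rank  # smallest singular values, scaled by 1/rank
down = torch.diag(torch.sqrt(Sr)) @ Vh[-rank:]
up = U[:, -rank:] @ torch.diag(torch.sqrt(Sr))

# up @ down equals the minor-component reconstruction of W, scaled by 1/rank
minor = U[:, -rank:] @ torch.diag(S[-rank:]) @ Vh[-rank:]
assert torch.allclose(up @ down, minor / rank, atol=1e-5)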
@@ -269,3 +270,107 @@ def initialize_pissa(
     org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
     org_module.weight.requires_grad = org_module_requires_grad
+
+
+def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
+    # Calculate ΔW = A'B' - AB
+    delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
+
+    # We need to create new low-rank matrices that represent this delta
+    U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
+
+    # Take the top 2*r singular values (as suggested in the paper)
+    rank = rank * 2
+    rank = min(rank, len(S))  # Make sure we don't exceed available singular values
+
+    # Create new LoRA matrices
+    new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
+    new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
+
+    # These matrices can now be used as standard LoRA weights
+    return new_up, new_down
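The doubled rank is enough to represent ΔW = A'B' - AB exactly, since a difference of two rank-r products has rank at most 2r. A usage sketch (shapes are illustrative; note the function above moves delta_w to CUDA, so it assumes a GPU is available):

import torch

out_features, in_features, r = 128, 64, 8

# Hypothetical PiSSA factors before and after fine-tuning
orig_up = torch.randn(out_features, r)
orig_down = torch.randn(r, in_features)
trained_up = orig_up + 0.01 * torch.randn(out_features, r)
trained_down = orig_down + 0.01 * torch.randn(r, in_features)

new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank=r)
print(new_up.shape, new_down.shape)  # torch.Size([128, 16]) torch.Size([16, 64]), rank doubled to 2*r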
+
+
+def convert_urae_to_standard_lora(
+    trained_up: Tensor,
+    trained_down: Tensor,
+    orig_up: Tensor,
+    orig_down: Tensor,
+    initial_alpha: float | None = None,
+    rank: int | None = None,
+):
+    """
+    Convert URAE trained weights to standard LoRA format
+
+    Args:
+        trained_up: The trained URAE up matrix
+        trained_down: The trained URAE down matrix
+        orig_up: The original up matrix before training
+        orig_down: The original down matrix before training
+        initial_alpha: The alpha value used during URAE training (if any)
+        rank: The rank for the standard LoRA (if None, uses the rank of trained_up)
+
+    Returns:
+        lora_up: Standard LoRA up matrix
+        lora_down: Standard LoRA down matrix
+        alpha: Appropriate alpha value for the LoRA
+    """
+    # Calculate the weight delta
+    delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
+
+    # Perform SVD on the delta
+    U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False)
+
+    # If rank is not specified, use the same rank as the trained matrices
+    if rank is None:
+        rank = trained_up.shape[1]
+    else:
+        # Ensure we don't exceed available singular values
+        rank = min(rank, len(S))
+
+    # Create standard LoRA matrices using the top singular values
+    # This is now standard LoRA (using top values), not URAE (which used bottom values during training)
+    lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
+    lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
+
+    # Method 1: Preserve the Frobenius norm of the delta
+    original_effect: float = torch.norm(delta_w, p="fro").item()
+    unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item()
+
+    # The scaling factor in LoRA is (alpha/r), so:
+    #   (alpha/r) * ||AB|| = ||delta_w||
+    #   alpha = r * ||delta_w|| / ||AB||
+    if unscaled_lora_effect > 0:
+        norm_based_alpha = rank * (original_effect / unscaled_lora_effect)
+    else:
+        norm_based_alpha = 1.0  # Fallback
+
+    # Method 2: If initial_alpha is provided, adjust based on rank change
+    if initial_alpha is not None:
+        initial_rank = trained_up.shape[1]
+        # Scale alpha proportionally if rank changed
+        rank_adjusted_alpha = initial_alpha * (rank / initial_rank)
+    else:
+        rank_adjusted_alpha = None
+
+    # Choose the appropriate alpha
+    if rank_adjusted_alpha is not None:
+        # Use the rank-adjusted alpha, but cap it within a reasonable
+        # range of the norm-based value to avoid extreme scaling
+        alpha = rank_adjusted_alpha
+        if norm_based_alpha > 0:
+            max_factor = 5.0  # Allow up to 5x difference
+            upper_bound = norm_based_alpha * max_factor
+            lower_bound = norm_based_alpha / max_factor
+            alpha = min(max(alpha, lower_bound), upper_bound)
+    else:
+        # Use the norm-based alpha
+        alpha = norm_based_alpha
+
+    # Round to a clean value for better usability
+    alpha = round(alpha, 2)
+
+    # Ensure alpha is positive and within reasonable bounds
+    alpha = max(0.1, min(alpha, 1024.0))
+
+    return lora_up, lora_down, alpha
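A quick check of the norm-based alpha path (random shapes, purely illustrative): with initial_alpha left unset, the returned alpha makes the scaled adapter match the Frobenius norm of the true delta, up to rounding and clamping.

import torch

out_features, in_features, r = 128, 64, 8
orig_up = torch.randn(out_features, r)
orig_down = torch.randn(r, in_features)
trained_up = orig_up + 0.05 * torch.randn(out_features, r)
trained_down = orig_down + 0.05 * torch.randn(r, in_features)

lora_up, lora_down, alpha = convert_urae_to_standard_lora(trained_up, trained_down, orig_up, orig_down)

delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
scaled_norm = (alpha / r) * torch.norm(lora_up @ lora_down, p="fro").item()
print(scaled_norm, torch.norm(delta_w, p="fro").item())  # approximately equal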