Remove lowrank for URAE

This commit is contained in:
rockerBOO
2025-06-03 18:31:52 -04:00
parent 2e2b07fa1b
commit ed46280407
2 changed files with 10 additions and 48 deletions

View File

@@ -3,7 +3,6 @@ import math
import warnings
from torch import Tensor
from typing import Optional
from library.incremental_pca import IncrementalPCA
from dataclasses import dataclass
@@ -24,7 +23,6 @@ def initialize_parse_opts(key: str) -> InitializeParams:
- "pissa" -> Default PiSSA with lowrank=True, niter=4
- "pissa_niter_4" -> PiSSA with niter=4
- "pissa_lowrank_false" -> PiSSA without lowrank
- "urae_..." -> Same options but for URAE
Args:
key: String key to parse
@@ -80,11 +78,7 @@ def initialize_urae(
rank: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
use_ipca: bool = False,
use_lowrank: bool = True,
lowrank_q: Optional[int] = None,
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,
**kwargs,
):
# Store original device, dtype, and requires_grad status
orig_device = org_module.weight.device
@@ -98,45 +92,17 @@ def initialize_urae(
# Move original weight to chosen device and use float32 for numerical stability
weight = org_module.weight.data.to(device, dtype=torch.float32)
with torch.autocast(device.type), torch.no_grad():
# Perform SVD decomposition (either directly or with IPCA for memory efficiency)
if use_ipca:
ipca = IncrementalPCA(
n_components=None,
batch_size=1024,
lowrank=use_lowrank,
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
lowrank_niter=lowrank_niter,
lowrank_seed=lowrank_seed,
)
ipca.fit(weight)
with torch.no_grad():
# Direct SVD approach
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
# Extract singular values and vectors, focusing on the minor components (smallest singular values)
S_full = ipca.singular_values_
V_full = ipca.components_.T # Shape: [out_features, total_rank]
# Extract the minor components (smallest singular values)
Sr = S[-rank:]
Vr = U[:, -rank:]
Uhr = Vh[-rank:]
# Get identity matrix to transform for right singular vectors
identity = torch.eye(weight.shape[1], device=weight.device)
Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features]
# Extract the last 'rank' components (the minor/smallest ones)
Sr = S_full[-rank:]
Vr = V_full[:, -rank:]
Uhr = Uhr_full[-rank:]
# Scale singular values
Sr = Sr / rank
else:
# Direct SVD approach
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
# Extract the minor components (smallest singular values)
Sr = S[-rank:]
Vr = U[:, -rank:]
Uhr = Vh[-rank:]
# Scale singular values
Sr = Sr / rank
# Scale singular values
Sr = Sr / rank
# Create the low-rank adapter matrices by splitting the minor components
# Down matrix: scaled right singular vectors with singular values

View File

@@ -1,4 +0,0 @@
import torch
import pytest
from library.network_utils import initialize_pissa