From ed46280407980dad176989187524bdc851b03d2d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 3 Jun 2025 18:31:52 -0400 Subject: [PATCH] Remove lowrank for URAE --- library/network_utils.py | 54 ++++++----------------------- tests/library/test_network_utils.py | 4 --- 2 files changed, 10 insertions(+), 48 deletions(-) delete mode 100644 tests/library/test_network_utils.py diff --git a/library/network_utils.py b/library/network_utils.py index 5c6e7e42..6edd0ff8 100644 --- a/library/network_utils.py +++ b/library/network_utils.py @@ -3,7 +3,6 @@ import math import warnings from torch import Tensor from typing import Optional -from library.incremental_pca import IncrementalPCA from dataclasses import dataclass @@ -24,7 +23,6 @@ def initialize_parse_opts(key: str) -> InitializeParams: - "pissa" -> Default PiSSA with lowrank=True, niter=4 - "pissa_niter_4" -> PiSSA with niter=4 - "pissa_lowrank_false" -> PiSSA without lowrank - - "urae_..." -> Same options but for URAE Args: key: String key to parse @@ -80,11 +78,7 @@ def initialize_urae( rank: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - use_ipca: bool = False, - use_lowrank: bool = True, - lowrank_q: Optional[int] = None, - lowrank_niter: int = 4, - lowrank_seed: Optional[int] = None, + **kwargs, ): # Store original device, dtype, and requires_grad status orig_device = org_module.weight.device @@ -98,45 +92,17 @@ def initialize_urae( # Move original weight to chosen device and use float32 for numerical stability weight = org_module.weight.data.to(device, dtype=torch.float32) - with torch.autocast(device.type), torch.no_grad(): - # Perform SVD decomposition (either directly or with IPCA for memory efficiency) - if use_ipca: - ipca = IncrementalPCA( - n_components=None, - batch_size=1024, - lowrank=use_lowrank, - lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), - lowrank_niter=lowrank_niter, - lowrank_seed=lowrank_seed, - ) - ipca.fit(weight) + with torch.no_grad(): + # Direct SVD approach + U, S, Vh = torch.linalg.svd(weight, full_matrices=False) - # Extract singular values and vectors, focusing on the minor components (smallest singular values) - S_full = ipca.singular_values_ - V_full = ipca.components_.T # Shape: [out_features, total_rank] + # Extract the minor components (smallest singular values) + Sr = S[-rank:] + Vr = U[:, -rank:] + Uhr = Vh[-rank:] - # Get identity matrix to transform for right singular vectors - identity = torch.eye(weight.shape[1], device=weight.device) - Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features] - - # Extract the last 'rank' components (the minor/smallest ones) - Sr = S_full[-rank:] - Vr = V_full[:, -rank:] - Uhr = Uhr_full[-rank:] - - # Scale singular values - Sr = Sr / rank - else: - # Direct SVD approach - U, S, Vh = torch.linalg.svd(weight, full_matrices=False) - - # Extract the minor components (smallest singular values) - Sr = S[-rank:] - Vr = U[:, -rank:] - Uhr = Vh[-rank:] - - # Scale singular values - Sr = Sr / rank + # Scale singular values + Sr = Sr / rank # Create the low-rank adapter matrices by splitting the minor components # Down matrix: scaled right singular vectors with singular values diff --git a/tests/library/test_network_utils.py b/tests/library/test_network_utils.py deleted file mode 100644 index fb29b8b6..00000000 --- a/tests/library/test_network_utils.py +++ /dev/null @@ -1,4 +0,0 @@ -import torch -import pytest -from library.network_utils import initialize_pissa -