Update URAE initialization and conversion; add tests

This commit is contained in:
rockerBOO
2025-05-08 19:05:08 -04:00
parent e0f1ae0f2c
commit 89b6f8bcea
3 changed files with 427 additions and 102 deletions

View File

@@ -1,6 +1,7 @@
import torch
import math
import warnings
from torch import Tensor
from typing import Optional
from library.incremental_pca import IncrementalPCA
from dataclasses import dataclass
@@ -108,73 +109,73 @@ def initialize_urae(
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,
):
org_module_device = org_module.weight.device
org_module_weight_dtype = org_module.weight.data.dtype
org_module_requires_grad = org_module.weight.requires_grad
# Store original device, dtype, and requires_grad status
orig_device = org_module.weight.device
orig_dtype = org_module.weight.data.dtype
orig_requires_grad = org_module.weight.requires_grad
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
# Determine device and dtype to work with
device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
assert isinstance(device, torch.device), f"Invalid device type: {device}"
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
# Move original weight to chosen device and use float32 for numerical stability
weight = org_module.weight.data.to(device, dtype=torch.float32)
# Perform SVD decomposition (either directly or with IPCA for memory efficiency)
if use_ipca:
# For URAE we need all components to get the "residual" ones
ipca = IncrementalPCA(
n_components=None, # Get all components
n_components=None,
batch_size=1024,
lowrank=use_lowrank,
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), # Use full rank for accurate residuals
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
lowrank_niter=lowrank_niter,
lowrank_seed=lowrank_seed,
)
ipca.fit(weight)
# For URAE, use the LAST/SMALLEST singular values
total_rank = min(weight.shape[0], weight.shape[1])
V_full = ipca.components_.T # [out_features, total_rank]
S_full = ipca.singular_values_ # [total_rank]
# Extract singular values and vectors, focusing on the minor components (smallest singular values)
S_full = ipca.singular_values_
V_full = ipca.components_.T # Shape: [out_features, total_rank]
# Get the smallest singular values and vectors
Vr = V_full[:, -rank:] # Last rank left singular vectors
Sr = S_full[-rank:] # Last rank singular values
Sr /= rank
# To get Uhr (last rank right singular vectors), transform basis vectors
# Get identity matrix to transform for right singular vectors
identity = torch.eye(weight.shape[1], device=weight.device)
Uhr_full = ipca.transform(identity).T # [total_rank, in_features]
Uhr = Uhr_full[-rank:] # Last rank right singular vectors
Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features]
# Extract the last 'rank' components (the minor/smallest ones)
Sr = S_full[-rank:]
Vr = V_full[:, -rank:]
Uhr = Uhr_full[-rank:]
# Scale singular values
Sr = Sr / rank
else:
# Standard SVD approach
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, -rank:]
# Direct SVD approach
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
# Extract the minor components (smallest singular values)
Sr = S[-rank:]
Sr /= rank
Uhr = Uh[-rank:, :]
Vr = U[:, -rank:]
Uhr = Vh[-rank:]
# Create down and up matrices
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
# Scale singular values
Sr = Sr / rank
# Get expected shapes
expected_down_shape = lora_down.weight.shape
expected_up_shape = lora_up.weight.shape
# Create the low-rank adapter matrices by splitting the minor components
# Down matrix: scaled right singular vectors with singular values
down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr
# Verify shapes match expected
if down.shape != expected_down_shape:
warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
# Up matrix: scaled left singular vectors with singular values
up_matrix = Vr @ torch.diag(torch.sqrt(Sr))
if up.shape != expected_up_shape:
warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
# Assign to LoRA modules
lora_down.weight.data = down_matrix.to(device=device, dtype=dtype)
lora_up.weight.data = up_matrix.to(device=device, dtype=dtype)
# Assign to LoRA weights
lora_up.weight.data = up
lora_down.weight.data = down
# Optionally, subtract from original weight
weight = weight - scale * (up @ down)
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
org_module.weight.requires_grad = org_module_requires_grad
# Update the original weight by removing the minor components
# This is equivalent to keeping only the major components
modified_weight = weight - scale * (up_matrix @ down_matrix)
org_module.weight.data = modified_weight.to(device=orig_device, dtype=orig_dtype)
org_module.weight.requires_grad = orig_requires_grad
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
@@ -269,3 +270,107 @@ def initialize_pissa(
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
org_module.weight.requires_grad = org_module_requires_grad
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
    """Convert trained PiSSA adapter weights into standard LoRA matrices.

    Args:
        trained_up: Trained up matrix (A').
        trained_down: Trained down matrix (B').
        orig_up: Up matrix (A) captured at initialization.
        orig_down: Down matrix (B) captured at initialization.
        rank: Rank of the PiSSA adapter; the result uses up to 2*rank
            singular values (as suggested in the PiSSA paper).

    Returns:
        Tuple of (new_up, new_down): low-rank factors whose product
        approximates the training delta A'B' - AB. They can be used
        directly as standard LoRA weights.
    """
    # Calculate ΔW = A'B' - AB
    delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)

    # Fix: don't hard-code CUDA — fall back to the tensor's own device so the
    # conversion also works on CPU-only machines. float32 for SVD stability.
    svd_device = torch.device("cuda") if torch.cuda.is_available() else delta_w.device
    U, S, Vh = torch.linalg.svd(delta_w.to(device=svd_device, dtype=torch.float32), full_matrices=False)

    # Take the top 2*r singular values (as suggested in the paper), but never
    # more than are actually available.
    out_rank = min(rank * 2, len(S))

    # Split sqrt(S) across both factors so new_up @ new_down reconstructs
    # the truncated U @ diag(S) @ Vh.
    sqrt_s = torch.sqrt(S[:out_rank])
    new_up = U[:, :out_rank] @ torch.diag(sqrt_s)
    new_down = torch.diag(sqrt_s) @ Vh[:out_rank, :]

    # These matrices can now be used as standard LoRA weights
    return new_up, new_down
def convert_urae_to_standard_lora(
trained_up: Tensor,
trained_down: Tensor,
orig_up: Tensor,
orig_down: Tensor,
initial_alpha: float | None = None,
rank: int | None = None,
):
"""
Convert URAE trained weights to standard LoRA format
Args:
trained_up: The trained URAE Up matrix
trained_down: The trained URAE Down matrix
orig_up: The original up matrix before training
orig_down: The original down matrix before training
initial_alpha: The alpha value used during URAE training (if any)
rank: The rank for the standard LoRA (if None, uses the rank of trained_A)
Returns:
lora_up: Standard LoRA up matrix
lora_down: Standard LoRA down matrix
alpha: Appropriate alpha value for the LoRA
"""
# Calculate the weight delta
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
# Perform SVD on the delta
U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False)
# If rank is not specified, use the same rank as the trained matrices
if rank is None:
rank = trained_up.shape[1]
else:
# Ensure we don't exceed available singular values
rank = min(rank, len(S))
# Create standard LoRA matrices using top singular values
# This is now standard LoRA (using top values), not URAE (which used bottom values during training)
lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
# Method 1: Preserve the Frobenius norm of the delta
original_effect: float = torch.norm(delta_w, p="fro").item()
unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item()
# The scaling factor in lora is (alpha/r), so:
# alpha/r × ||AB|| = ||delta_W||
# alpha = r × ||delta_W|| / ||AB||
if unscaled_lora_effect > 0:
norm_based_alpha = rank * (original_effect / unscaled_lora_effect)
else:
norm_based_alpha = 1.0 # Fallback
# Method 2: If initial_alpha is provided, adjust based on rank change
if initial_alpha is not None:
initial_rank = trained_up.shape[1]
# Scale alpha proportionally if rank changed
rank_adjusted_alpha = initial_alpha * (rank / initial_rank)
else:
rank_adjusted_alpha = None
# Choose the appropriate alpha
if rank_adjusted_alpha is not None:
# Use the rank-adjusted alpha, but ensure it's not too different from norm-based
# Cap the difference to avoid extreme values
alpha = rank_adjusted_alpha
# Optional: Cap alpha to be within a reasonable range of norm_based_alpha
if norm_based_alpha > 0:
max_factor = 5.0 # Allow up to 5x difference
upper_bound = norm_based_alpha * max_factor
lower_bound = norm_based_alpha / max_factor
alpha = min(max(alpha, lower_bound), upper_bound)
else:
# Use norm-based alpha
alpha = norm_based_alpha
# Round to a clean value for better usability
alpha = round(alpha, 2)
# Ensure alpha is positive and within reasonable bounds
alpha = max(0.1, min(alpha, 1024.0))
return lora_up, lora_down, alpha

View File

@@ -19,7 +19,14 @@ from tqdm import tqdm
import re
from library.utils import setup_logging
from library.device_utils import clean_memory_on_device
from library.network_utils import initialize_lora, initialize_pissa, initialize_urae, initialize_parse_opts
from library.network_utils import (
initialize_lora,
initialize_pissa,
initialize_urae,
initialize_parse_opts,
convert_pissa_to_standard_lora,
convert_urae_to_standard_lora,
)
setup_logging()
import logging
@@ -49,6 +56,7 @@ class LoRAModule(torch.nn.Module):
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
initialize: Optional[str] = None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -101,7 +109,6 @@ class LoRAModule(torch.nn.Module):
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]):
"""
Initialize the weights for the LoRA
@@ -113,12 +120,16 @@ class LoRAModule(torch.nn.Module):
if initialize is not None:
params = initialize_parse_opts(initialize)
if initialize[:4] == "urae":
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
initialize_urae(
org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)
)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif initialize[:5] == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
initialize_pissa(
org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)
)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
@@ -149,7 +160,6 @@ class LoRAModule(torch.nn.Module):
self._org_lora_up = self._org_lora_up.to("cpu")
self._org_lora_down = self._org_lora_down.to("cpu")
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
@@ -607,7 +617,7 @@ def create_network(
if verbose is not None:
verbose = True if verbose == "True" else False
# Computation device, used in initialization
# Computation device, used in initialization
comp_device = kwargs.get("comp_device", None)
if comp_device is not None:
comp_device = torch.device(comp_device)
@@ -922,7 +932,9 @@ class LoRANetwork(torch.nn.Module):
elif "single" in lora_name and "linear1" in lora_name:
split_dims = [3072] * 3 + [12288]
assert module_class is LoRAModule or module_class is LoRAInfModule, f"Module class is not valid {type(module_class)}"
assert module_class is LoRAModule or module_class is LoRAInfModule, (
f"Module class is not valid {type(module_class)}"
)
lora = module_class(
lora_name,
child_module,
@@ -960,7 +972,7 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"Initialize Text Encoder LoRA using {initialize}")
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index+1}.")
logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index + 1}.")
self.text_encoder_loras.extend(text_encoder_loras)
skipped_te += skipped
@@ -1313,66 +1325,33 @@ class LoRANetwork(torch.nn.Module):
# Need to decompose the parameters into a LoRA format
if self.initialize is not None and (self.initialize[:5] == "pissa" or self.initialize[:4] == "urae"):
loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
# Calculate ΔW = A'B' - AB
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
# We need to create new low-rank matrices that represent this delta
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
# Take the top 2*r singular values (as suggested in the paper)
rank = rank * 2
rank = min(rank, len(S)) # Make sure we don't exceed available singular values
# Create new LoRA matrices
new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
# These matrices can now be used as standard LoRA weights
return new_up, new_down
def convert_urae_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
# Calculate ΔW = A'B' - AB
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
# We need to create new low-rank matrices that represent this delta
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
# For URAE, we want to focus on the smallest singular values
# Take the bottom rank*2 singular values (opposite of PiSSA which takes the top ones)
total_rank = len(S)
rank_to_use = min(rank * 2, total_rank)
if rank_to_use < total_rank:
# Use the smallest singular values and vectors
selected_U = U[:, -rank_to_use:]
selected_S = S[-rank_to_use:]
selected_V = V[-rank_to_use:, :]
else:
# If we'd use all values, just use the standard approach but with a note
print("Warning: Requested rank is too large for URAE specialty, using all singular values")
selected_U = U
selected_S = S
selected_V = V
# Create new LoRA matrices
new_up = selected_U @ torch.diag(torch.sqrt(selected_S))
new_down = torch.diag(torch.sqrt(selected_S)) @ selected_V
# These matrices can now be used as standard LoRA weights
return new_up, new_down
with torch.no_grad():
progress = tqdm(total=len(loras), desc="Converting")
for lora in loras:
lora_up_key = f"{lora.lora_name}.lora_up.weight"
lora_down_key = f"{lora.lora_name}.lora_down.weight"
lora_alpha_key = f"{lora.lora_name}.alpha"
lora_up = state_dict[lora_up_key]
lora_down = state_dict[lora_down_key]
if self.initialize[:4] == "urae":
up, down = convert_urae_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
up, down, alpha = convert_urae_to_standard_lora(
lora_up,
lora_down,
lora._org_lora_up.to(lora_up.device),
lora._org_lora_down.to(lora_up.device),
initial_alpha=lora.alpha.item(),
rank=lora.lora_dim,
)
state_dict[lora_alpha_key] = torch.tensor(alpha).to(dtype=lora.alpha.dtype)
elif self.initialize[:5] == "pissa":
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
up, down = convert_pissa_to_standard_lora(
lora_up,
lora_down,
lora._org_lora_up.to(lora_up.device),
lora._org_lora_down.to(lora_up.device),
lora.lora_dim,
)
# TODO: Capture option if we should offload
# offload to CPU

View File

@@ -0,0 +1,241 @@
import pytest
import torch
from library.network_utils import convert_urae_to_standard_lora
class TestConvertURAEToStandardLoRA:
    """Tests for convert_urae_to_standard_lora.

    Covers: basic shape/alpha behavior, an explicit rank override,
    alpha derivation from initial_alpha (including clamping to the
    [0.1, 1024.0] range), the zero-delta edge case, larger matrices,
    and a requested rank exceeding the available singular values.
    """

    @pytest.fixture
    def sample_matrices(self):
        """Create sample matrices for testing"""
        # Original up matrix (4x2)
        orig_up = torch.tensor([
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0]
        ])
        # Original down matrix (2x6)
        orig_down = torch.tensor([
            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        ])
        # Trained up matrix (4x2) - same shape as orig_up but with changed values
        trained_up = torch.tensor([
            [1.1, 2.1],
            [3.1, 4.1],
            [5.1, 6.1],
            [7.1, 8.1]
        ])
        # Trained down matrix (2x6) - same shape as orig_down but with changed values
        trained_down = torch.tensor([
            [0.15, 0.25, 0.35, 0.45, 0.55, 0.65],
            [0.75, 0.85, 0.95, 1.05, 1.15, 1.25]
        ])
        return {
            'orig_up': orig_up,
            'orig_down': orig_down,
            'trained_up': trained_up,
            'trained_down': trained_down
        }

    def test_basic_conversion(self, sample_matrices):
        """Test the basic functionality of convert_urae_to_standard_lora"""
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down']
        )
        # Check shapes
        assert lora_up.shape[0] == sample_matrices['trained_up'].shape[0]  # Same number of rows as trained_up
        assert lora_up.shape[1] == sample_matrices['trained_up'].shape[1]  # Same rank as trained_up
        assert lora_down.shape[0] == sample_matrices['trained_up'].shape[1]  # Same rank as trained_up
        assert lora_down.shape[1] == sample_matrices['trained_down'].shape[1]  # Same number of columns as trained_down
        # Check alpha is a reasonable value
        assert 0.1 <= alpha <= 1024.0
        # Check that lora_up @ lora_down approximates the weight delta
        delta = (sample_matrices['trained_up'] @ sample_matrices['trained_down']) - (sample_matrices['orig_up'] @ sample_matrices['orig_down'])
        # The approximation should be close in Frobenius norm after scaling
        lora_effect = lora_up @ lora_down
        delta_norm = torch.norm(delta, p="fro").item()
        lora_norm = torch.norm(lora_effect, p="fro").item()
        # Either they are close, or the alpha scaling brings them close
        scaled_lora_effect = (alpha / sample_matrices['trained_up'].shape[1]) * lora_effect
        scaled_lora_norm = torch.norm(scaled_lora_effect, p="fro").item()
        # At least one of these should be true.
        # NOTE(review): the 1e-4 tolerance relies on the fixture's delta being
        # (numerically) rank-2 so the unscaled reconstruction is exact; the
        # scaled branch alone could miss 1e-4 because alpha is rounded to two
        # decimals — confirm if the fixture is ever changed.
        assert abs(delta_norm - lora_norm) < 1e-4 or abs(delta_norm - scaled_lora_norm) < 1e-4

    def test_specified_rank(self, sample_matrices):
        """Test conversion with a specified rank"""
        new_rank = 1  # Lower than trained_up's rank of 2
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            rank=new_rank
        )
        # Check that the new rank is used
        assert lora_up.shape[1] == new_rank
        assert lora_down.shape[0] == new_rank
        # Should still produce a reasonable alpha
        assert 0.1 <= alpha <= 1024.0

    def test_with_initial_alpha(self, sample_matrices):
        """Test conversion with a specified initial alpha"""
        initial_alpha = 16.0
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha
        )
        # Alpha should be influenced by initial_alpha but may be adjusted
        # (the converter caps it within 5x of its norm-based estimate).
        # Since we're using same rank, should be reasonably close to initial_alpha
        assert 0.1 <= alpha <= 1024.0
        # Should at least preserve the order of magnitude in typical cases
        assert abs(alpha - initial_alpha) <= initial_alpha * 4.0

    def test_large_initial_alpha(self, sample_matrices):
        """Test conversion with a very large initial alpha that should be capped"""
        initial_alpha = 2000.0  # Larger than the 1024.0 cap
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha
        )
        # Alpha should be capped at 1024.0
        assert alpha <= 1024.0

    def test_very_small_initial_alpha(self, sample_matrices):
        """Test conversion with a very small initial alpha that should be floored"""
        initial_alpha = 0.01  # Smaller than the 0.1 floor
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha
        )
        # Alpha should be floored at 0.1
        assert alpha >= 0.1

    def test_change_rank_with_initial_alpha(self, sample_matrices):
        """Test conversion with both rank change and initial alpha"""
        initial_alpha = 16.0
        new_rank = 1  # Half of original rank 2
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha,
            rank=new_rank
        )
        # Check shapes
        assert lora_up.shape[1] == new_rank
        assert lora_down.shape[0] == new_rank
        # Alpha should be adjusted for the rank change (approx halved in this case)
        expected_alpha = initial_alpha * (new_rank / sample_matrices['trained_up'].shape[1])
        # Allow some tolerance for adjustments from norm-based capping
        assert abs(alpha - expected_alpha) <= expected_alpha * 4.0 or alpha >= 0.1

    def test_zero_delta(self):
        """Test conversion when delta is zero"""
        # Create matrices where the delta will be zero
        dim_in, rank, dim_out = 4, 2, 6
        # Create identical matrices for original and trained
        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up.clone()
        trained_down = orig_down.clone()
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down
        )
        # Should still return matrices of correct shape
        assert lora_up.shape == (dim_in, rank)
        assert lora_down.shape == (rank, dim_out)
        # Alpha should be at least the minimum (the converter falls back to a
        # default when the adapter norm is zero)
        assert alpha >= 0.1

    def test_large_dimensions(self):
        """Test with larger matrix dimensions"""
        dim_in, rank, dim_out = 100, 8, 200
        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up + 0.01 * torch.randn(dim_in, rank)  # Small perturbation
        trained_down = orig_down + 0.01 * torch.randn(rank, dim_out)  # Small perturbation
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down
        )
        # Check shapes
        assert lora_up.shape == (dim_in, rank)
        assert lora_down.shape == (rank, dim_out)
        # Should produce a reasonable alpha
        assert 0.1 <= alpha <= 1024.0

    def test_rank_exceeding_singular_values(self):
        """Test when requested rank exceeds available singular values"""
        # Small matrices with limited rank
        dim_in, rank, dim_out = 3, 2, 3
        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up + 0.1 * torch.randn(dim_in, rank)
        trained_down = orig_down + 0.1 * torch.randn(rank, dim_out)
        # Request rank larger than possible
        too_large_rank = 10
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down,
            rank=too_large_rank
        )
        # Rank should be limited to min(dim_in, dim_out, S.size)
        max_possible_rank = min(dim_in, dim_out)
        assert lora_up.shape[1] <= max_possible_rank
        assert lora_down.shape[0] <= max_possible_rank