From 6c6f3171a2f205ebc5aeb6809a8be2568a1101b6 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 3 Jun 2025 17:47:59 -0400 Subject: [PATCH] Fix lowrank PISSA. Add more tests --- library/flux_train_utils.py | 2 +- library/network_utils.py | 29 +- tests/library/test_network_utils.py | 385 --------------- tests/library/test_network_utils_pissa.py | 569 ++++++++++++++++++++++ tests/library/test_network_utils_urae.py | 194 +++----- 5 files changed, 660 insertions(+), 519 deletions(-) create mode 100644 tests/library/test_network_utils_pissa.py diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index e5fb8163..c841f816 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype + args, noise_scheduler, latents, noise, device, dtype, num_timesteps=1000 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape sigmas = None diff --git a/library/network_utils.py b/library/network_utils.py index f58cee4e..5c6e7e42 100644 --- a/library/network_utils.py +++ b/library/network_utils.py @@ -12,7 +12,7 @@ class InitializeParams: """Parameters for initialization methods (PiSSA, URAE)""" use_lowrank: bool = False - lowrank_q: Optional[int] = None + # lowrank_q: Optional[int] = None lowrank_niter: int = 4 @@ -24,8 +24,6 @@ def initialize_parse_opts(key: str) -> InitializeParams: - "pissa" -> Default PiSSA with lowrank=True, niter=4 - "pissa_niter_4" -> PiSSA with niter=4 - "pissa_lowrank_false" -> PiSSA without lowrank - - "pissa_q_16" -> PiSSA with lowrank_q=16 - - "pissa_seed_42" -> PiSSA with seed=42 - "urae_..." -> Same options but for URAE Args: @@ -57,12 +55,7 @@ def initialize_parse_opts(key: str) -> InitializeParams: elif parts[i] == "niter": if i + 1 < len(parts) and parts[i + 1].isdigit(): params.lowrank_niter = int(parts[i + 1]) - i += 2 - else: - i += 1 - elif parts[i] == "q": - if i + 1 < len(parts) and parts[i + 1].isdigit(): - params.lowrank_q = int(parts[i + 1]) + params.use_lowrank = True i += 2 else: i += 1 @@ -173,7 +166,7 @@ def initialize_pissa( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, use_lowrank: bool = False, - lowrank_q: Optional[int] = None, + # lowrank_q: Optional[int] = None, lowrank_niter: int = 4, ): org_module_device = org_module.weight.device @@ -188,17 +181,17 @@ def initialize_pissa( with torch.no_grad(): if use_lowrank: - q_value = lowrank_q if lowrank_q is not None else 2 * rank - Vr, Sr, Ur = torch.svd_lowrank(weight.data, q=q_value, niter=lowrank_niter) + # q_value = lowrank_q if lowrank_q is not None else 2 * rank + Vr, Sr, Ur = torch.svd_lowrank(weight.data, q=rank, niter=lowrank_niter) Sr /= rank Uhr = Ur.t() else: # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False) - Vr = V[:, : rank] - Sr = S[: rank] + Vr = V[:, :rank] + Sr = S[:rank] Sr /= rank - Uhr = Uh[: rank] + Uhr = Uh[:rank] down = torch.diag(torch.sqrt(Sr)) @ Uhr up = Vr @ torch.diag(torch.sqrt(Sr)) @@ -222,13 +215,15 @@ def initialize_pissa( org_module.weight.requires_grad = org_module_requires_grad -def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int): +def convert_pissa_to_standard_lora( + trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int +): with torch.no_grad(): # Calculate ΔW = A'B' - AB delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) # We need to create new low-rank matrices that represent this delta - U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False) + U, S, V = torch.linalg.svd(delta_w.to(trained_up.device, dtype=torch.float32), full_matrices=False) # Take the top 2*r singular values (as suggested in the paper) rank = rank * 2 diff --git a/tests/library/test_network_utils.py b/tests/library/test_network_utils.py index 363f2b85..fb29b8b6 100644 --- a/tests/library/test_network_utils.py +++ b/tests/library/test_network_utils.py @@ -2,388 +2,3 @@ import torch import pytest from library.network_utils import initialize_pissa - -def generate_synthetic_weights(org_weight, seed=42): - generator = torch.manual_seed(seed) - - # Base random normal distribution - weights = torch.randn_like(org_weight) - - # Add structured variance to mimic real-world weight matrices - # Techniques to create more realistic weight distributions: - - # 1. Block-wise variation - block_size = max(1, org_weight.shape[0] // 4) - for i in range(0, org_weight.shape[0], block_size): - block_end = min(i + block_size, org_weight.shape[0]) - block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling - weights[i:block_end, :] *= 1 + block_variation - - # 2. Sparse connectivity simulation - sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity - weights *= sparsity_mask.float() - - # 3. Magnitude decay - magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1) - weights *= magnitude_decay - - # 4. Add structured noise - structural_noise = torch.randn_like(org_weight) * 0.1 - weights += structural_noise - - # Normalize to have similar statistical properties to trained weights - weights = (weights - weights.mean()) / weights.std() - - return weights - - -def test_initialize_pissa_rank_constraints(): - # Test with different rank values - org_module = torch.nn.Linear(20, 10) - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - torch.nn.init.xavier_uniform_(org_module.weight) - torch.nn.init.zeros_(org_module.bias) - - # Test with rank less than min dimension - lora_down = torch.nn.Linear(20, 3) - lora_up = torch.nn.Linear(3, 10) - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) - - # Test with rank equal to min dimension - lora_down = torch.nn.Linear(20, 10) - lora_up = torch.nn.Linear(10, 10) - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10) - - -def test_initialize_pissa_rank_limits(): - # Test rank limits - org_module = torch.nn.Linear(10, 5) - - # Test minimum rank (should work) - lora_down_min = torch.nn.Linear(10, 1) - lora_up_min = torch.nn.Linear(1, 5) - initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1) - - # Test maximum rank (rank = min(input_dim, output_dim)) - max_rank = min(10, 5) - lora_down_max = torch.nn.Linear(10, max_rank) - lora_up_max = torch.nn.Linear(max_rank, 5) - initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank) - - -def test_initialize_pissa_basic(): - # Create a simple linear layer - org_module = torch.nn.Linear(10, 5) - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - torch.nn.init.xavier_uniform_(org_module.weight) - torch.nn.init.zeros_(org_module.bias) - - # Create LoRA layers with matching shapes - lora_down = torch.nn.Linear(10, 2) - lora_up = torch.nn.Linear(2, 5) - - # Store original weight for comparison - original_weight = org_module.weight.data.clone() - - # Call the initialization function - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) - - # Verify basic properties - assert lora_down.weight.data is not None - assert lora_up.weight.data is not None - assert org_module.weight.data is not None - - # Check that the weights have been modified - assert not torch.equal(original_weight, org_module.weight.data) - - -def test_initialize_pissa_with_lowrank(): - # Test with low-rank SVD option - org_module = torch.nn.Linear(50, 30) - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - lora_down = torch.nn.Linear(50, 5) - lora_up = torch.nn.Linear(5, 30) - - original_weight = org_module.weight.data.clone() - - # Call with low-rank SVD enabled - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True) - - # Verify weights are changed - assert not torch.equal(original_weight, org_module.weight.data) - - -def test_initialize_pissa_custom_lowrank_params(): - # Test with custom low-rank parameters - org_module = torch.nn.Linear(30, 20) - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - lora_down = torch.nn.Linear(30, 5) - lora_up = torch.nn.Linear(5, 20) - - # Test with custom q value and iterations - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True, lowrank_q=12, lowrank_niter=6) - - # Check basic validity - assert lora_down.weight.data is not None - assert lora_up.weight.data is not None - - -def test_initialize_pissa_device_handling(): - # Test different device scenarios - devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")] - - for device in devices: - # Create modules on specific device - org_module = torch.nn.Linear(10, 5).to(device) - lora_down = torch.nn.Linear(10, 2).to(device) - lora_up = torch.nn.Linear(2, 5).to(device) - - # Test initialization with explicit device - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device) - - # Verify modules are on the correct device - assert org_module.weight.data.device.type == device.type - assert lora_down.weight.data.device.type == device.type - assert lora_up.weight.data.device.type == device.type - - # Test with IPCA - if device.type == "cpu": # IPCA might be slow on CPU for large matrices - org_module_small = torch.nn.Linear(20, 10).to(device) - lora_down_small = torch.nn.Linear(20, 3).to(device) - lora_up_small = torch.nn.Linear(3, 10).to(device) - - initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device) - - assert org_module_small.weight.data.device.type == device.type - - -def test_initialize_pissa_shape_mismatch(): - # Test with shape mismatch to ensure warning is printed - org_module = torch.nn.Linear(20, 10) - - # Intentionally mismatched shapes to test warning mechanism - lora_down = torch.nn.Linear(20, 5) # Different shape - lora_up = torch.nn.Linear(3, 15) # Different shape - - with pytest.warns(UserWarning): - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) - - -def test_initialize_pissa_scaling(): - # Test different scaling factors - scales = [0.0, 0.1, 1.0] - - for scale in scales: - org_module = torch.nn.Linear(10, 5) - org_module.weight.data = generate_synthetic_weights(org_module.weight) - original_weight = org_module.weight.data.clone() - - lora_down = torch.nn.Linear(10, 2) - lora_up = torch.nn.Linear(2, 5) - - initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2) - - # Check that the weight modification follows the scaling - weight_diff = original_weight - org_module.weight.data - expected_diff = scale * (lora_up.weight.data @ lora_down.weight.data) - - torch.testing.assert_close(weight_diff, expected_diff, rtol=1e-4, atol=1e-4) - - -def test_initialize_pissa_dtype(): - # Test with different data types - dtypes = [torch.float16, torch.float32, torch.float64] - - for dtype in dtypes: - org_module = torch.nn.Linear(10, 5).to(dtype=dtype) - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - lora_down = torch.nn.Linear(10, 2) - lora_up = torch.nn.Linear(2, 5) - - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) - - # Verify output dtype matches input - assert org_module.weight.dtype == dtype - - -def test_initialize_pissa_svd_properties(): - # Verify SVD decomposition properties - org_module = torch.nn.Linear(20, 10) - lora_down = torch.nn.Linear(20, 3) - lora_up = torch.nn.Linear(3, 10) - - org_module.weight.data = generate_synthetic_weights(org_module.weight) - original_weight = org_module.weight.data.clone() - - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) - - # Reconstruct the weight - reconstructed_weight = original_weight - 0.1 * (lora_up.weight.data @ lora_down.weight.data) - - # Check reconstruction is close to original - torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4) - - -def test_initialize_pissa_dtype_preservation(): - # Test dtype preservation and conversion - dtypes = [torch.float16, torch.float32, torch.float64] - - for dtype in dtypes: - org_module = torch.nn.Linear(10, 5).to(dtype=dtype) - lora_down = torch.nn.Linear(10, 2).to(dtype=dtype) - lora_up = torch.nn.Linear(2, 5).to(dtype=dtype) - - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) - - assert org_module.weight.dtype == dtype - assert lora_down.weight.dtype == dtype - assert lora_up.weight.dtype == dtype - - # Test with explicit dtype - if dtype != torch.float16: # Skip float16 for computational stability in SVD - org_module2 = torch.nn.Linear(10, 5).to(dtype=torch.float32) - lora_down2 = torch.nn.Linear(10, 2).to(dtype=torch.float32) - lora_up2 = torch.nn.Linear(2, 5).to(dtype=torch.float32) - - initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype) - - # Original module should be converted to specified dtype - assert org_module2.weight.dtype == torch.float32 - - -def test_initialize_pissa_numerical_stability(): - # Test with numerically challenging scenarios - scenarios = [ - torch.randn(20, 10) * 1e-5, # Small values - torch.randn(20, 10) * 1e5, # Large values - torch.ones(20, 10), # Uniform values - ] - - for i, weight_matrix in enumerate(scenarios): - org_module = torch.nn.Linear(20, 10) - org_module.weight.data = weight_matrix - - lora_down = torch.nn.Linear(20, 3) - lora_up = torch.nn.Linear(3, 10) - - try: - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) - - # Test IPCA as well - lora_down_ipca = torch.nn.Linear(20, 3) - lora_up_ipca = torch.nn.Linear(3, 10) - initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3) - except Exception as e: - pytest.fail(f"Initialization failed for scenario ({i}): {e}") - - -def test_initialize_pissa_scale_effects(): - # Test effect of different scaling factors - org_module = torch.nn.Linear(15, 10) - original_weight = torch.randn_like(org_module.weight.data) - org_module.weight.data = original_weight.clone() - - # Try different scales - scales = [0.0, 0.01, 0.1, 1.0] - - for scale in scales: - # Reset to original weights - org_module.weight.data = original_weight.clone() - - lora_down = torch.nn.Linear(15, 4) - lora_up = torch.nn.Linear(4, 10) - - initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=4) - - # Verify weight modification proportional to scale - weight_diff = original_weight - org_module.weight.data - - # Approximate check of scaling effect - if scale == 0.0: - torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6) - else: - # For non-zero scales, verify the magnitude of change is proportional to scale - assert weight_diff.abs().sum() > 0 - - # Do a second run with double the scale - org_module2 = torch.nn.Linear(15, 10) - org_module2.weight.data = original_weight.clone() - - lora_down2 = torch.nn.Linear(15, 4) - lora_up2 = torch.nn.Linear(4, 10) - - initialize_pissa(org_module2, lora_down2, lora_up2, scale=scale * 2, rank=4) - - weight_diff2 = original_weight - org_module2.weight.data - - # The ratio of differences should be approximately 2 - # (allowing for numerical precision issues) - ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10) - assert 1.9 < ratio < 2.1 - - -def test_initialize_pissa_large_matrix_performance(): - # Test with a large matrix to ensure it works well - # This is particularly relevant for IPCA mode - - # Skip if running on CPU to avoid long test times - if not torch.cuda.is_available(): - pytest.skip("Skipping large matrix test on CPU") - - org_module = torch.nn.Linear(1000, 500) - org_module.weight.data = torch.randn_like(org_module.weight.data) * 0.1 - - lora_down = torch.nn.Linear(1000, 16) - lora_up = torch.nn.Linear(16, 500) - - # Test standard approach - try: - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=16) - except Exception as e: - pytest.fail(f"Standard SVD failed on large matrix: {e}") - - # Test IPCA approach - lora_down_ipca = torch.nn.Linear(1000, 16) - lora_up_ipca = torch.nn.Linear(16, 500) - - try: - initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16) - except Exception as e: - pytest.fail(f"IPCA approach failed on large matrix: {e}") - - # Test IPCA with lowrank - lora_down_both = torch.nn.Linear(1000, 16) - lora_up_both = torch.nn.Linear(16, 500) - - try: - initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_lowrank=True) - except Exception as e: - pytest.fail(f"Combined IPCA+lowrank approach failed on large matrix: {e}") - - -def test_initialize_pissa_requires_grad_preservation(): - # Test that requires_grad property is preserved - org_module = torch.nn.Linear(20, 10) - org_module.weight.requires_grad = False - - lora_down = torch.nn.Linear(20, 4) - lora_up = torch.nn.Linear(4, 10) - - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=4) - - # Check requires_grad is preserved - assert not org_module.weight.requires_grad - - # Test with requires_grad=True - org_module2 = torch.nn.Linear(20, 10) - org_module2.weight.requires_grad = True - - initialize_pissa(org_module2, lora_down, lora_up, scale=0.1, rank=4) - - # Check requires_grad is preserved - assert org_module2.weight.requires_grad diff --git a/tests/library/test_network_utils_pissa.py b/tests/library/test_network_utils_pissa.py new file mode 100644 index 00000000..93825cd2 --- /dev/null +++ b/tests/library/test_network_utils_pissa.py @@ -0,0 +1,569 @@ +import pytest +import torch +from torch import Tensor +from typing import Tuple + +from library.network_utils import convert_pissa_to_standard_lora, initialize_pissa + + +def generate_synthetic_weights(org_weight, seed=42): + generator = torch.manual_seed(seed) + + # Base random normal distribution + weights = torch.randn_like(org_weight) + + # Add structured variance to mimic real-world weight matrices + # Techniques to create more realistic weight distributions: + + # 1. Block-wise variation + block_size = max(1, org_weight.shape[0] // 4) + for i in range(0, org_weight.shape[0], block_size): + block_end = min(i + block_size, org_weight.shape[0]) + block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling + weights[i:block_end, :] *= 1 + block_variation + + # 2. Sparse connectivity simulation + sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity + weights *= sparsity_mask.float() + + # 3. Magnitude decay + magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1) + weights *= magnitude_decay + + # 4. Add structured noise + structural_noise = torch.randn_like(org_weight) * 0.1 + weights += structural_noise + + # Normalize to have similar statistical properties to trained weights + weights = (weights - weights.mean()) / weights.std() + + return weights + + +class TestPissa: + """Test suite for convert_pissa_to_standard_lora function.""" + + @pytest.fixture + def basic_matrices(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, int]: + """Create basic test matrices with known properties.""" + torch.manual_seed(42) + d_model, rank = 64, 8 + + # Create original matrices + orig_up = torch.randn(d_model, rank, dtype=torch.float32) + orig_down = torch.randn(rank, d_model, dtype=torch.float32) + + # Create trained matrices (slightly different) + noise_scale = 0.1 + trained_up = orig_up + noise_scale * torch.randn_like(orig_up) + trained_down = orig_down + noise_scale * torch.randn_like(orig_down) + + return trained_up, trained_down, orig_up, orig_down, rank + + @pytest.fixture + def small_matrices(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, int]: + """Create small matrices for easier debugging.""" + torch.manual_seed(123) + d_model, rank = 8, 2 + + orig_up = torch.randn(d_model, rank, dtype=torch.float32) + orig_down = torch.randn(rank, d_model, dtype=torch.float32) + trained_up = orig_up + 0.1 * torch.randn_like(orig_up) + trained_down = orig_down + 0.1 * torch.randn_like(orig_down) + + return trained_up, trained_down, orig_up, orig_down, rank + + def test_initialize_pissa_rank_constraints(self): + # Test with different rank values + org_module = torch.nn.Linear(20, 10) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + torch.nn.init.xavier_uniform_(org_module.weight) + torch.nn.init.zeros_(org_module.bias) + + # Test with rank less than min dimension + lora_down = torch.nn.Linear(20, 3) + lora_up = torch.nn.Linear(3, 10) + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + # Test with rank equal to min dimension + lora_down = torch.nn.Linear(20, 10) + lora_up = torch.nn.Linear(10, 10) + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10) + + def test_initialize_pissa_rank_limits(self): + # Test rank limits + org_module = torch.nn.Linear(10, 5) + + # Test minimum rank (should work) + lora_down_min = torch.nn.Linear(10, 1) + lora_up_min = torch.nn.Linear(1, 5) + initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1) + + # Test maximum rank (rank = min(input_dim, output_dim)) + max_rank = min(10, 5) + lora_down_max = torch.nn.Linear(10, max_rank) + lora_up_max = torch.nn.Linear(max_rank, 5) + initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank) + + def test_initialize_pissa_basic(self): + # Create a simple linear layer + org_module = torch.nn.Linear(10, 5) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + torch.nn.init.xavier_uniform_(org_module.weight) + torch.nn.init.zeros_(org_module.bias) + + # Create LoRA layers with matching shapes + lora_down = torch.nn.Linear(10, 2) + lora_up = torch.nn.Linear(2, 5) + + # Store original weight for comparison + original_weight = org_module.weight.data.clone() + + # Call the initialization function + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) + + # Verify basic properties + assert lora_down.weight.data is not None + assert lora_up.weight.data is not None + assert org_module.weight.data is not None + + # Check that the weights have been modified + assert not torch.equal(original_weight, org_module.weight.data) + + def test_initialize_pissa_with_lowrank(self): + # Test with low-rank SVD option + org_module = torch.nn.Linear(50, 30) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + lora_down = torch.nn.Linear(50, 5) + lora_up = torch.nn.Linear(5, 30) + + original_weight = org_module.weight.data.clone() + + # Call with low-rank SVD enabled + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True) + + # Verify weights are changed + assert not torch.equal(original_weight, org_module.weight.data) + + def test_initialize_pissa_custom_lowrank_params(self): + # Test with custom low-rank parameters + org_module = torch.nn.Linear(30, 20) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + lora_down = torch.nn.Linear(30, 5) + lora_up = torch.nn.Linear(5, 20) + + # Test with custom q value and iterations + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True, lowrank_niter=6) + + # Check basic validity + assert lora_down.weight.data is not None + assert lora_up.weight.data is not None + + def test_initialize_pissa_device_handling(self): + # Test different device scenarios + devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")] + + for device in devices: + # Create modules on specific device + org_module = torch.nn.Linear(10, 5).to(device) + lora_down = torch.nn.Linear(10, 2).to(device) + lora_up = torch.nn.Linear(2, 5).to(device) + + # Test initialization with explicit device + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device) + + # Verify modules are on the correct device + assert org_module.weight.data.device.type == device.type + assert lora_down.weight.data.device.type == device.type + assert lora_up.weight.data.device.type == device.type + + # Test with IPCA + if device.type == "cpu": # IPCA might be slow on CPU for large matrices + org_module_small = torch.nn.Linear(20, 10).to(device) + lora_down_small = torch.nn.Linear(20, 3).to(device) + lora_up_small = torch.nn.Linear(3, 10).to(device) + + initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device) + + assert org_module_small.weight.data.device.type == device.type + + def test_initialize_pissa_shape_mismatch(self): + # Test with shape mismatch to ensure warning is printed + org_module = torch.nn.Linear(20, 10) + + # Intentionally mismatched shapes to test warning mechanism + lora_down = torch.nn.Linear(20, 5) # Different shape + lora_up = torch.nn.Linear(3, 15) # Different shape + + with pytest.warns(UserWarning): + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + def test_initialize_pissa_scaling(self): + # Test different scaling factors + scales = [0.0, 0.1, 1.0] + + for scale in scales: + org_module = torch.nn.Linear(10, 5) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + original_weight = org_module.weight.data.clone() + + lora_down = torch.nn.Linear(10, 2) + lora_up = torch.nn.Linear(2, 5) + + initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2) + + # Check that the weight modification follows the scaling + weight_diff = original_weight - org_module.weight.data + expected_diff = scale * (lora_up.weight.data @ lora_down.weight.data) + + torch.testing.assert_close(weight_diff, expected_diff, rtol=1e-4, atol=1e-4) + + def test_initialize_pissa_dtype(self): + # Test with different data types + dtypes = [torch.float16, torch.float32, torch.float64] + + for dtype in dtypes: + org_module = torch.nn.Linear(10, 5).to(dtype=dtype) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + lora_down = torch.nn.Linear(10, 2) + lora_up = torch.nn.Linear(2, 5) + + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) + + # Verify output dtype matches input + assert org_module.weight.dtype == dtype + + def test_initialize_pissa_svd_properties(self): + # Verify SVD decomposition properties + org_module = torch.nn.Linear(20, 10) + lora_down = torch.nn.Linear(20, 3) + lora_up = torch.nn.Linear(3, 10) + + org_module.weight.data = generate_synthetic_weights(org_module.weight) + original_weight = org_module.weight.data.clone() + + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + # Reconstruct the weight + reconstructed_weight = original_weight - 0.1 * (lora_up.weight.data @ lora_down.weight.data) + + # Check reconstruction is close to original + torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4) + + def test_initialize_pissa_dtype_preservation(self): + # Test dtype preservation and conversion + dtypes = [torch.float16, torch.float32, torch.float64] + + for dtype in dtypes: + org_module = torch.nn.Linear(10, 5).to(dtype=dtype) + lora_down = torch.nn.Linear(10, 2).to(dtype=dtype) + lora_up = torch.nn.Linear(2, 5).to(dtype=dtype) + + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) + + assert org_module.weight.dtype == dtype + assert lora_down.weight.dtype == dtype + assert lora_up.weight.dtype == dtype + + # Test with explicit dtype + if dtype != torch.float16: # Skip float16 for computational stability in SVD + org_module2 = torch.nn.Linear(10, 5).to(dtype=torch.float32) + lora_down2 = torch.nn.Linear(10, 2).to(dtype=torch.float32) + lora_up2 = torch.nn.Linear(2, 5).to(dtype=torch.float32) + + initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype) + + # Original module should be converted to specified dtype + assert org_module2.weight.dtype == torch.float32 + + def test_initialize_pissa_numerical_stability(self): + # Test with numerically challenging scenarios + scenarios = [ + torch.randn(20, 10) * 1e-5, # Small values + torch.randn(20, 10) * 1e5, # Large values + torch.ones(20, 10), # Uniform values + ] + + for i, weight_matrix in enumerate(scenarios): + org_module = torch.nn.Linear(20, 10) + org_module.weight.data = weight_matrix + + lora_down = torch.nn.Linear(20, 3) + lora_up = torch.nn.Linear(3, 10) + + try: + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + # Test IPCA as well + lora_down_ipca = torch.nn.Linear(20, 3) + lora_up_ipca = torch.nn.Linear(3, 10) + initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3) + except Exception as e: + pytest.fail(f"Initialization failed for scenario ({i}): {e}") + + def test_initialize_pissa_scale_effects(self): + # Test effect of different scaling factors + org_module = torch.nn.Linear(15, 10) + original_weight = torch.randn_like(org_module.weight.data) + org_module.weight.data = original_weight.clone() + + # Try different scales + scales = [0.0, 0.01, 0.1, 1.0] + + for scale in scales: + # Reset to original weights + org_module.weight.data = original_weight.clone() + + lora_down = torch.nn.Linear(15, 4) + lora_up = torch.nn.Linear(4, 10) + + initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=4) + + # Verify weight modification proportional to scale + weight_diff = original_weight - org_module.weight.data + + # Approximate check of scaling effect + if scale == 0.0: + torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6) + else: + # For non-zero scales, verify the magnitude of change is proportional to scale + assert weight_diff.abs().sum() > 0 + + # Do a second run with double the scale + org_module2 = torch.nn.Linear(15, 10) + org_module2.weight.data = original_weight.clone() + + lora_down2 = torch.nn.Linear(15, 4) + lora_up2 = torch.nn.Linear(4, 10) + + initialize_pissa(org_module2, lora_down2, lora_up2, scale=scale * 2, rank=4) + + weight_diff2 = original_weight - org_module2.weight.data + + # The ratio of differences should be approximately 2 + # (allowing for numerical precision issues) + ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10) + assert 1.9 < ratio < 2.1 + + def test_initialize_pissa_large_matrix_performance(self): + # Test with a large matrix to ensure it works well + # This is particularly relevant for IPCA mode + + # Skip if running on CPU to avoid long test times + if not torch.cuda.is_available(): + pytest.skip("Skipping large matrix test on CPU") + + org_module = torch.nn.Linear(1000, 500) + org_module.weight.data = torch.randn_like(org_module.weight.data) * 0.1 + + lora_down = torch.nn.Linear(1000, 16) + lora_up = torch.nn.Linear(16, 500) + + # Test standard approach + try: + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=16) + except Exception as e: + pytest.fail(f"Standard SVD failed on large matrix: {e}") + + # Test IPCA approach + lora_down_ipca = torch.nn.Linear(1000, 16) + lora_up_ipca = torch.nn.Linear(16, 500) + + try: + initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16) + except Exception as e: + pytest.fail(f"IPCA approach failed on large matrix: {e}") + + # Test IPCA with lowrank + lora_down_both = torch.nn.Linear(1000, 16) + lora_up_both = torch.nn.Linear(16, 500) + + try: + initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_lowrank=True) + except Exception as e: + pytest.fail(f"Combined IPCA+lowrank approach failed on large matrix: {e}") + + def test_initialize_pissa_requires_grad_preservation(self): + # Test that requires_grad property is preserved + org_module = torch.nn.Linear(20, 10) + org_module.weight.requires_grad = False + + lora_down = torch.nn.Linear(20, 4) + lora_up = torch.nn.Linear(4, 10) + + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=4) + + # Check requires_grad is preserved + assert not org_module.weight.requires_grad + + # Test with requires_grad=True + org_module2 = torch.nn.Linear(20, 10) + org_module2.weight.requires_grad = True + + initialize_pissa(org_module2, lora_down, lora_up, scale=0.1, rank=4) + + # Check requires_grad is preserved + assert org_module2.weight.requires_grad + + def test_basic_functionality(self, basic_matrices): + """Test that the function runs without errors and returns expected shapes.""" + trained_up, trained_down, orig_up, orig_down, rank = basic_matrices + + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # Check output types + assert isinstance(new_up, torch.Tensor) + assert isinstance(new_down, torch.Tensor) + + # Check shapes - should be compatible for matrix multiplication + d_model = trained_up.shape[0] + expected_rank = min(rank * 2, min(d_model, trained_down.shape[1])) + + assert new_up.shape == torch.Size([d_model, expected_rank]) + assert new_down.shape == (expected_rank, trained_down.shape[1]) + + def test_delta_preservation(self, basic_matrices): + """Test that the delta weight is preserved in the LoRA decomposition.""" + trained_up, trained_down, orig_up, orig_down, rank = basic_matrices + + # Calculate original delta + original_delta = (trained_up @ trained_down) - (orig_up @ orig_down) + + # Convert to LoRA + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # Reconstruct delta from LoRA matrices + reconstructed_delta = new_up @ new_down + + # Check that reconstruction approximates original delta + # (Note: some information loss is expected due to rank reduction) + relative_error = torch.norm(original_delta - reconstructed_delta) / torch.norm(original_delta) + assert relative_error < 0.5 # Allow some approximation error + + def test_rank_handling(self, small_matrices): + """Test various rank scenarios.""" + trained_up, trained_down, orig_up, orig_down, base_rank = small_matrices + d_model = trained_up.shape[0] + + # Test with rank that would exceed matrix dimensions + large_rank = d_model + 5 + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, large_rank) + + # Should not exceed available singular values + max_possible_rank = min(d_model, trained_down.shape[1]) + assert new_up.shape[1] <= max_possible_rank + assert new_down.shape[0] <= max_possible_rank + + def test_zero_delta(self): + """Test behavior when trained and original matrices are identical.""" + torch.manual_seed(456) + d_model, rank = 16, 4 + + # Create identical matrices + orig_up = torch.randn(d_model, rank, dtype=torch.float32) + orig_down = torch.randn(rank, d_model, dtype=torch.float32) + trained_up = orig_up.clone() + trained_down = orig_down.clone() + + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # Reconstructed delta should be close to zero + reconstructed_delta = new_up @ new_down + assert torch.allclose(reconstructed_delta, torch.zeros_like(reconstructed_delta), atol=1e-6) + + def test_different_devices(self, basic_matrices): + """Test that the function handles different device placement correctly.""" + trained_up, trained_down, orig_up, orig_down, rank = basic_matrices + + # Test with CPU tensors + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # Results should be on the same device as input + assert new_up.device == trained_up.device + assert new_down.device == trained_up.device + + def test_gradient_disabled(self, basic_matrices): + """Test that gradients are properly disabled.""" + trained_up, trained_down, orig_up, orig_down, rank = basic_matrices + + # Enable gradients on inputs + trained_up.requires_grad_(True) + trained_down.requires_grad_(True) + + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # Outputs should not require gradients due to torch.no_grad() + assert not new_up.requires_grad + assert not new_down.requires_grad + + def test_dtype_consistency(self, basic_matrices): + """Test that output dtypes are consistent.""" + trained_up, trained_down, orig_up, orig_down, rank = basic_matrices + + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # Should maintain float32 dtype + assert new_up.dtype == torch.float32 + assert new_down.dtype == torch.float32 + + def test_mathematical_properties(self, small_matrices): + """Test mathematical properties of the SVD decomposition.""" + trained_up, trained_down, orig_up, orig_down, rank = small_matrices + + # Calculate delta manually + delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) + + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # The decomposition should satisfy: new_up @ new_down ≈ low-rank approximation of delta_w + reconstructed = new_up @ new_down + + # Check that reconstruction has expected rank + actual_rank = torch.linalg.matrix_rank(reconstructed).item() + expected_max_rank = min(rank * 2, min(delta_w.shape)) + assert actual_rank <= expected_max_rank + + @pytest.mark.parametrize("rank", [1, 4, 8, 16]) + def test_different_ranks(self, rank): + """Test the function with different rank values.""" + torch.manual_seed(789) + d_model = 32 + + orig_up = torch.randn(d_model, rank, dtype=torch.float32) + orig_down = torch.randn(rank, d_model, dtype=torch.float32) + trained_up = orig_up + 0.1 * torch.randn_like(orig_up) + trained_down = orig_down + 0.1 * torch.randn_like(orig_down) + + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # Should handle all rank values gracefully + assert new_up.shape[0] == d_model + assert new_down.shape[1] == d_model + assert new_up.shape[1] == new_down.shape[0] # Compatible for multiplication + + def test_edge_case_single_rank(self): + """Test with minimal rank (rank=1).""" + torch.manual_seed(101) + d_model, rank = 8, 1 + + orig_up = torch.randn(d_model, rank, dtype=torch.float32) + orig_down = torch.randn(rank, d_model, dtype=torch.float32) + trained_up = orig_up + 0.2 * torch.randn_like(orig_up) + trained_down = orig_down + 0.2 * torch.randn_like(orig_down) + + new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank) + + # With rank=1, output rank should be 2 (rank * 2) + expected_rank = min(2, min(d_model, d_model)) + assert new_up.shape[1] <= expected_rank + assert new_down.shape[0] <= expected_rank + + +if __name__ == "__main__": + # Run the tests + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_network_utils_urae.py b/tests/library/test_network_utils_urae.py index 581e13b8..7bbac8a8 100644 --- a/tests/library/test_network_utils_urae.py +++ b/tests/library/test_network_utils_urae.py @@ -4,109 +4,87 @@ import torch from library.network_utils import convert_urae_to_standard_lora -class TestConvertURAEToStandardLoRA: +class TestURAE: @pytest.fixture def sample_matrices(self): """Create sample matrices for testing""" # Original up matrix (4x2) - orig_up = torch.tensor([ - [1.0, 2.0], - [3.0, 4.0], - [5.0, 6.0], - [7.0, 8.0] - ]) - + orig_up = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]) + # Original down matrix (2x6) - orig_down = torch.tensor([ - [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], - [0.7, 0.8, 0.9, 1.0, 1.1, 1.2] - ]) - + orig_down = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]]) + # Trained up matrix (4x2) - same shape as orig_up but with changed values - trained_up = torch.tensor([ - [1.1, 2.1], - [3.1, 4.1], - [5.1, 6.1], - [7.1, 8.1] - ]) - + trained_up = torch.tensor([[1.1, 2.1], [3.1, 4.1], [5.1, 6.1], [7.1, 8.1]]) + # Trained down matrix (2x6) - same shape as orig_down but with changed values - trained_down = torch.tensor([ - [0.15, 0.25, 0.35, 0.45, 0.55, 0.65], - [0.75, 0.85, 0.95, 1.05, 1.15, 1.25] - ]) - - return { - 'orig_up': orig_up, - 'orig_down': orig_down, - 'trained_up': trained_up, - 'trained_down': trained_down - } + trained_down = torch.tensor([[0.15, 0.25, 0.35, 0.45, 0.55, 0.65], [0.75, 0.85, 0.95, 1.05, 1.15, 1.25]]) + + return {"orig_up": orig_up, "orig_down": orig_down, "trained_up": trained_up, "trained_down": trained_down} def test_basic_conversion(self, sample_matrices): """Test the basic functionality of convert_urae_to_standard_lora""" lora_up, lora_down, alpha = convert_urae_to_standard_lora( - sample_matrices['trained_up'], - sample_matrices['trained_down'], - sample_matrices['orig_up'], - sample_matrices['orig_down'] + sample_matrices["trained_up"], sample_matrices["trained_down"], sample_matrices["orig_up"], sample_matrices["orig_down"] ) - + # Check shapes - assert lora_up.shape[0] == sample_matrices['trained_up'].shape[0] # Same number of rows as trained_up - assert lora_up.shape[1] == sample_matrices['trained_up'].shape[1] # Same rank as trained_up - assert lora_down.shape[0] == sample_matrices['trained_up'].shape[1] # Same rank as trained_up - assert lora_down.shape[1] == sample_matrices['trained_down'].shape[1] # Same number of columns as trained_down - + assert lora_up.shape[0] == sample_matrices["trained_up"].shape[0] # Same number of rows as trained_up + assert lora_up.shape[1] == sample_matrices["trained_up"].shape[1] # Same rank as trained_up + assert lora_down.shape[0] == sample_matrices["trained_up"].shape[1] # Same rank as trained_up + assert lora_down.shape[1] == sample_matrices["trained_down"].shape[1] # Same number of columns as trained_down + # Check alpha is a reasonable value assert 0.1 <= alpha <= 1024.0 - + # Check that lora_up @ lora_down approximates the weight delta - delta = (sample_matrices['trained_up'] @ sample_matrices['trained_down']) - (sample_matrices['orig_up'] @ sample_matrices['orig_down']) - + delta = (sample_matrices["trained_up"] @ sample_matrices["trained_down"]) - ( + sample_matrices["orig_up"] @ sample_matrices["orig_down"] + ) + # The approximation should be close in Frobenius norm after scaling lora_effect = lora_up @ lora_down delta_norm = torch.norm(delta, p="fro").item() lora_norm = torch.norm(lora_effect, p="fro").item() - + # Either they are close, or the alpha scaling brings them close - scaled_lora_effect = (alpha / sample_matrices['trained_up'].shape[1]) * lora_effect + scaled_lora_effect = (alpha / sample_matrices["trained_up"].shape[1]) * lora_effect scaled_lora_norm = torch.norm(scaled_lora_effect, p="fro").item() - + # At least one of these should be true assert abs(delta_norm - lora_norm) < 1e-4 or abs(delta_norm - scaled_lora_norm) < 1e-4 def test_specified_rank(self, sample_matrices): """Test conversion with a specified rank""" new_rank = 1 # Lower than trained_up's rank of 2 - + lora_up, lora_down, alpha = convert_urae_to_standard_lora( - sample_matrices['trained_up'], - sample_matrices['trained_down'], - sample_matrices['orig_up'], - sample_matrices['orig_down'], - rank=new_rank + sample_matrices["trained_up"], + sample_matrices["trained_down"], + sample_matrices["orig_up"], + sample_matrices["orig_down"], + rank=new_rank, ) - + # Check that the new rank is used assert lora_up.shape[1] == new_rank assert lora_down.shape[0] == new_rank - + # Should still produce a reasonable alpha assert 0.1 <= alpha <= 1024.0 def test_with_initial_alpha(self, sample_matrices): """Test conversion with a specified initial alpha""" initial_alpha = 16.0 - + lora_up, lora_down, alpha = convert_urae_to_standard_lora( - sample_matrices['trained_up'], - sample_matrices['trained_down'], - sample_matrices['orig_up'], - sample_matrices['orig_down'], - initial_alpha=initial_alpha + sample_matrices["trained_up"], + sample_matrices["trained_down"], + sample_matrices["orig_up"], + sample_matrices["orig_down"], + initial_alpha=initial_alpha, ) - + # Alpha should be influenced by initial_alpha but may be adjusted # Since we're using same rank, should be reasonably close to initial_alpha assert 0.1 <= alpha <= 1024.0 @@ -116,30 +94,30 @@ class TestConvertURAEToStandardLoRA: def test_large_initial_alpha(self, sample_matrices): """Test conversion with a very large initial alpha that should be capped""" initial_alpha = 2000.0 # Larger than the 1024.0 cap - + lora_up, lora_down, alpha = convert_urae_to_standard_lora( - sample_matrices['trained_up'], - sample_matrices['trained_down'], - sample_matrices['orig_up'], - sample_matrices['orig_down'], - initial_alpha=initial_alpha + sample_matrices["trained_up"], + sample_matrices["trained_down"], + sample_matrices["orig_up"], + sample_matrices["orig_down"], + initial_alpha=initial_alpha, ) - + # Alpha should be capped at 1024.0 assert alpha <= 1024.0 def test_very_small_initial_alpha(self, sample_matrices): """Test conversion with a very small initial alpha that should be floored""" initial_alpha = 0.01 # Smaller than the 0.1 floor - + lora_up, lora_down, alpha = convert_urae_to_standard_lora( - sample_matrices['trained_up'], - sample_matrices['trained_down'], - sample_matrices['orig_up'], - sample_matrices['orig_down'], - initial_alpha=initial_alpha + sample_matrices["trained_up"], + sample_matrices["trained_down"], + sample_matrices["orig_up"], + sample_matrices["orig_down"], + initial_alpha=initial_alpha, ) - + # Alpha should be floored at 0.1 assert alpha >= 0.1 @@ -147,22 +125,22 @@ class TestConvertURAEToStandardLoRA: """Test conversion with both rank change and initial alpha""" initial_alpha = 16.0 new_rank = 1 # Half of original rank 2 - + lora_up, lora_down, alpha = convert_urae_to_standard_lora( - sample_matrices['trained_up'], - sample_matrices['trained_down'], - sample_matrices['orig_up'], - sample_matrices['orig_down'], + sample_matrices["trained_up"], + sample_matrices["trained_down"], + sample_matrices["orig_up"], + sample_matrices["orig_down"], initial_alpha=initial_alpha, - rank=new_rank + rank=new_rank, ) - + # Check shapes assert lora_up.shape[1] == new_rank assert lora_down.shape[0] == new_rank - + # Alpha should be adjusted for the rank change (approx halved in this case) - expected_alpha = initial_alpha * (new_rank / sample_matrices['trained_up'].shape[1]) + expected_alpha = initial_alpha * (new_rank / sample_matrices["trained_up"].shape[1]) # Allow some tolerance for adjustments from norm-based capping assert abs(alpha - expected_alpha) <= expected_alpha * 4.0 or alpha >= 0.1 @@ -170,47 +148,37 @@ class TestConvertURAEToStandardLoRA: """Test conversion when delta is zero""" # Create matrices where the delta will be zero dim_in, rank, dim_out = 4, 2, 6 - + # Create identical matrices for original and trained orig_up = torch.randn(dim_in, rank) orig_down = torch.randn(rank, dim_out) trained_up = orig_up.clone() trained_down = orig_down.clone() - - lora_up, lora_down, alpha = convert_urae_to_standard_lora( - trained_up, - trained_down, - orig_up, - orig_down - ) - + + lora_up, lora_down, alpha = convert_urae_to_standard_lora(trained_up, trained_down, orig_up, orig_down) + # Should still return matrices of correct shape assert lora_up.shape == (dim_in, rank) assert lora_down.shape == (rank, dim_out) - + # Alpha should be at least the minimum assert alpha >= 0.1 def test_large_dimensions(self): """Test with larger matrix dimensions""" dim_in, rank, dim_out = 100, 8, 200 - + orig_up = torch.randn(dim_in, rank) orig_down = torch.randn(rank, dim_out) trained_up = orig_up + 0.01 * torch.randn(dim_in, rank) # Small perturbation trained_down = orig_down + 0.01 * torch.randn(rank, dim_out) # Small perturbation - - lora_up, lora_down, alpha = convert_urae_to_standard_lora( - trained_up, - trained_down, - orig_up, - orig_down - ) - + + lora_up, lora_down, alpha = convert_urae_to_standard_lora(trained_up, trained_down, orig_up, orig_down) + # Check shapes assert lora_up.shape == (dim_in, rank) assert lora_down.shape == (rank, dim_out) - + # Should produce a reasonable alpha assert 0.1 <= alpha <= 1024.0 @@ -218,23 +186,17 @@ class TestConvertURAEToStandardLoRA: """Test when requested rank exceeds available singular values""" # Small matrices with limited rank dim_in, rank, dim_out = 3, 2, 3 - + orig_up = torch.randn(dim_in, rank) orig_down = torch.randn(rank, dim_out) trained_up = orig_up + 0.1 * torch.randn(dim_in, rank) trained_down = orig_down + 0.1 * torch.randn(rank, dim_out) - + # Request rank larger than possible too_large_rank = 10 - - lora_up, lora_down, alpha = convert_urae_to_standard_lora( - trained_up, - trained_down, - orig_up, - orig_down, - rank=too_large_rank - ) - + + lora_up, lora_down, alpha = convert_urae_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank=too_large_rank) + # Rank should be limited to min(dim_in, dim_out, S.size) max_possible_rank = min(dim_in, dim_out) assert lora_up.shape[1] <= max_possible_rank