diff --git a/library/network_utils.py b/library/network_utils.py index 423c7309..1b51cb44 100644 --- a/library/network_utils.py +++ b/library/network_utils.py @@ -121,51 +121,52 @@ def initialize_urae( # Move original weight to chosen device and use float32 for numerical stability weight = org_module.weight.data.to(device, dtype=torch.float32) - # Perform SVD decomposition (either directly or with IPCA for memory efficiency) - if use_ipca: - ipca = IncrementalPCA( - n_components=None, - batch_size=1024, - lowrank=use_lowrank, - lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), - lowrank_niter=lowrank_niter, - lowrank_seed=lowrank_seed, - ) - ipca.fit(weight) + with torch.autocast(device.type), torch.no_grad(): + # Perform SVD decomposition (either directly or with IPCA for memory efficiency) + if use_ipca: + ipca = IncrementalPCA( + n_components=None, + batch_size=1024, + lowrank=use_lowrank, + lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), + lowrank_niter=lowrank_niter, + lowrank_seed=lowrank_seed, + ) + ipca.fit(weight) - # Extract singular values and vectors, focusing on the minor components (smallest singular values) - S_full = ipca.singular_values_ - V_full = ipca.components_.T # Shape: [out_features, total_rank] + # Extract singular values and vectors, focusing on the minor components (smallest singular values) + S_full = ipca.singular_values_ + V_full = ipca.components_.T # Shape: [out_features, total_rank] - # Get identity matrix to transform for right singular vectors - identity = torch.eye(weight.shape[1], device=weight.device) - Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features] + # Get identity matrix to transform for right singular vectors + identity = torch.eye(weight.shape[1], device=weight.device) + Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features] - # Extract the last 'rank' components (the minor/smallest ones) - Sr = S_full[-rank:] - Vr = V_full[:, -rank:] - Uhr = Uhr_full[-rank:] + # Extract the last 'rank' components (the minor/smallest ones) + Sr = S_full[-rank:] + Vr = V_full[:, -rank:] + Uhr = Uhr_full[-rank:] - # Scale singular values - Sr = Sr / rank - else: - # Direct SVD approach - U, S, Vh = torch.linalg.svd(weight, full_matrices=False) + # Scale singular values + Sr = Sr / rank + else: + # Direct SVD approach + U, S, Vh = torch.linalg.svd(weight, full_matrices=False) - # Extract the minor components (smallest singular values) - Sr = S[-rank:] - Vr = U[:, -rank:] - Uhr = Vh[-rank:] + # Extract the minor components (smallest singular values) + Sr = S[-rank:] + Vr = U[:, -rank:] + Uhr = Vh[-rank:] - # Scale singular values - Sr = Sr / rank + # Scale singular values + Sr = Sr / rank - # Create the low-rank adapter matrices by splitting the minor components - # Down matrix: scaled right singular vectors with singular values - down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr + # Create the low-rank adapter matrices by splitting the minor components + # Down matrix: scaled right singular vectors with singular values + down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr - # Up matrix: scaled left singular vectors with singular values - up_matrix = Vr @ torch.diag(torch.sqrt(Sr)) + # Up matrix: scaled left singular vectors with singular values + up_matrix = Vr @ torch.diag(torch.sqrt(Sr)) # Assign to LoRA modules lora_down.weight.data = down_matrix.to(device=device, dtype=dtype) @@ -223,7 +224,8 @@ def initialize_pissa( # We need to get Uhr from transforming an identity matrix identity = torch.eye(weight.shape[1], device=weight.device) - Uhr = ipca.transform(identity).T # [rank, in_features] + with torch.autocast(device.type, dtype=torch.float64): + Uhr = ipca.transform(identity).T # [rank, in_features] elif use_lowrank: # Use low-rank SVD approximation which is faster @@ -248,9 +250,11 @@ def initialize_pissa( Sr /= rank Uhr = Uh[:rank] - # Create down and up matrices - down = torch.diag(torch.sqrt(Sr)) @ Uhr - up = Vr @ torch.diag(torch.sqrt(Sr)) + # Uhr may be in higher precision + with torch.autocast(device.type, dtype=Uhr.dtype): + # Create down and up matrices + down = torch.diag(torch.sqrt(Sr)) @ Uhr + up = Vr @ torch.diag(torch.sqrt(Sr)) # Get expected shapes expected_down_shape = lora_down.weight.shape @@ -272,19 +276,20 @@ def initialize_pissa( def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int): - # Calculate ΔW = A'B' - AB - delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) + with torch.no_grad(): + # Calculate ΔW = A'B' - AB + delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) - # We need to create new low-rank matrices that represent this delta - U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False) + # We need to create new low-rank matrices that represent this delta + U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False) - # Take the top 2*r singular values (as suggested in the paper) - rank = rank * 2 - rank = min(rank, len(S)) # Make sure we don't exceed available singular values + # Take the top 2*r singular values (as suggested in the paper) + rank = rank * 2 + rank = min(rank, len(S)) # Make sure we don't exceed available singular values - # Create new LoRA matrices - new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank])) - new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :] + # Create new LoRA matrices + new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank])) + new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :] # These matrices can now be used as standard LoRA weights return new_up, new_down @@ -314,63 +319,64 @@ def convert_urae_to_standard_lora( lora_down: Standard LoRA down matrix alpha: Appropriate alpha value for the LoRA """ - # Calculate the weight delta - delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) + with torch.no_grad(): + # Calculate the weight delta + delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) - # Perform SVD on the delta - U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False) + # Perform SVD on the delta + U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False) - # If rank is not specified, use the same rank as the trained matrices - if rank is None: - rank = trained_up.shape[1] - else: - # Ensure we don't exceed available singular values - rank = min(rank, len(S)) + # If rank is not specified, use the same rank as the trained matrices + if rank is None: + rank = trained_up.shape[1] + else: + # Ensure we don't exceed available singular values + rank = min(rank, len(S)) - # Create standard LoRA matrices using top singular values - # This is now standard LoRA (using top values), not URAE (which used bottom values during training) - lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank])) - lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :] + # Create standard LoRA matrices using top singular values + # This is now standard LoRA (using top values), not URAE (which used bottom values during training) + lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank])) + lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :] - # Method 1: Preserve the Frobenius norm of the delta - original_effect: float = torch.norm(delta_w, p="fro").item() - unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item() + # Method 1: Preserve the Frobenius norm of the delta + original_effect: float = torch.norm(delta_w, p="fro").item() + unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item() - # The scaling factor in lora is (alpha/r), so: - # alpha/r × ||AB|| = ||delta_W|| - # alpha = r × ||delta_W|| / ||AB|| - if unscaled_lora_effect > 0: - norm_based_alpha = rank * (original_effect / unscaled_lora_effect) - else: - norm_based_alpha = 1.0 # Fallback + # The scaling factor in lora is (alpha/r), so: + # alpha/r × ||AB|| = ||delta_W|| + # alpha = r × ||delta_W|| / ||AB|| + if unscaled_lora_effect > 0: + norm_based_alpha = rank * (original_effect / unscaled_lora_effect) + else: + norm_based_alpha = 1.0 # Fallback - # Method 2: If initial_alpha is provided, adjust based on rank change - if initial_alpha is not None: - initial_rank = trained_up.shape[1] - # Scale alpha proportionally if rank changed - rank_adjusted_alpha = initial_alpha * (rank / initial_rank) - else: - rank_adjusted_alpha = None + # Method 2: If initial_alpha is provided, adjust based on rank change + if initial_alpha is not None: + initial_rank = trained_up.shape[1] + # Scale alpha proportionally if rank changed + rank_adjusted_alpha = initial_alpha * (rank / initial_rank) + else: + rank_adjusted_alpha = None - # Choose the appropriate alpha - if rank_adjusted_alpha is not None: - # Use the rank-adjusted alpha, but ensure it's not too different from norm-based - # Cap the difference to avoid extreme values - alpha = rank_adjusted_alpha - # Optional: Cap alpha to be within a reasonable range of norm_based_alpha - if norm_based_alpha > 0: - max_factor = 5.0 # Allow up to 5x difference - upper_bound = norm_based_alpha * max_factor - lower_bound = norm_based_alpha / max_factor - alpha = min(max(alpha, lower_bound), upper_bound) - else: - # Use norm-based alpha - alpha = norm_based_alpha + # Choose the appropriate alpha + if rank_adjusted_alpha is not None: + # Use the rank-adjusted alpha, but ensure it's not too different from norm-based + # Cap the difference to avoid extreme values + alpha = rank_adjusted_alpha + # Optional: Cap alpha to be within a reasonable range of norm_based_alpha + if norm_based_alpha > 0: + max_factor = 5.0 # Allow up to 5x difference + upper_bound = norm_based_alpha * max_factor + lower_bound = norm_based_alpha / max_factor + alpha = min(max(alpha, lower_bound), upper_bound) + else: + # Use norm-based alpha + alpha = norm_based_alpha - # Round to a clean value for better usability - alpha = round(alpha, 2) + # Round to a clean value for better usability + alpha = round(alpha, 2) - # Ensure alpha is positive and within reasonable bounds - alpha = max(0.1, min(alpha, 1024.0)) + # Ensure alpha is positive and within reasonable bounds + alpha = max(0.1, min(alpha, 1024.0)) return lora_up, lora_down, alpha diff --git a/tests/library/test_network_utils.py b/tests/library/test_network_utils.py index e2e8c9a7..634c9937 100644 --- a/tests/library/test_network_utils.py +++ b/tests/library/test_network_utils.py @@ -4,6 +4,41 @@ from library.network_utils import initialize_pissa from library.test_util import generate_synthetic_weights +def test_initialize_pissa_rank_constraints(): + # Test with different rank values + org_module = torch.nn.Linear(20, 10) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + torch.nn.init.xavier_uniform_(org_module.weight) + torch.nn.init.zeros_(org_module.bias) + + # Test with rank less than min dimension + lora_down = torch.nn.Linear(20, 3) + lora_up = torch.nn.Linear(3, 10) + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + # Test with rank equal to min dimension + lora_down = torch.nn.Linear(20, 10) + lora_up = torch.nn.Linear(10, 10) + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10) + + +def test_initialize_pissa_rank_limits(): + # Test rank limits + org_module = torch.nn.Linear(10, 5) + + # Test minimum rank (should work) + lora_down_min = torch.nn.Linear(10, 1) + lora_up_min = torch.nn.Linear(1, 5) + initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1) + + # Test maximum rank (rank = min(input_dim, output_dim)) + max_rank = min(10, 5) + lora_down_max = torch.nn.Linear(10, max_rank) + lora_up_max = torch.nn.Linear(max_rank, 5) + initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank) + + def test_initialize_pissa_basic(): # Create a simple linear layer org_module = torch.nn.Linear(10, 5) @@ -31,23 +66,129 @@ def test_initialize_pissa_basic(): assert not torch.equal(original_weight, org_module.weight.data) -def test_initialize_pissa_rank_constraints(): - # Test with different rank values +def test_initialize_pissa_with_ipca(): + # Test with IncrementalPCA option + org_module = torch.nn.Linear(100, 50) # Larger dimensions to test IPCA + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + lora_down = torch.nn.Linear(100, 8) + lora_up = torch.nn.Linear(8, 50) + + original_weight = org_module.weight.data.clone() + + # Call with IPCA enabled + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=8, use_ipca=True) + + # Verify weights are changed + assert not torch.equal(original_weight, org_module.weight.data) + + # Check that LoRA matrices have appropriate shapes + assert lora_down.weight.shape == torch.Size([8, 100]) + assert lora_up.weight.shape == torch.Size([50, 8]) + + +def test_initialize_pissa_with_lowrank(): + # Test with low-rank SVD option + org_module = torch.nn.Linear(50, 30) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + lora_down = torch.nn.Linear(50, 5) + lora_up = torch.nn.Linear(5, 30) + + original_weight = org_module.weight.data.clone() + + # Call with low-rank SVD enabled + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True) + + # Verify weights are changed + assert not torch.equal(original_weight, org_module.weight.data) + + +def test_initialize_pissa_with_lowrank_seed(): + # Test reproducibility with seed org_module = torch.nn.Linear(20, 10) org_module.weight.data = generate_synthetic_weights(org_module.weight) - torch.nn.init.xavier_uniform_(org_module.weight) - torch.nn.init.zeros_(org_module.bias) + # First run with seed + lora_down1 = torch.nn.Linear(20, 3) + lora_up1 = torch.nn.Linear(3, 10) + initialize_pissa(org_module, lora_down1, lora_up1, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42) - # Test with rank less than min dimension - lora_down = torch.nn.Linear(20, 3) - lora_up = torch.nn.Linear(3, 10) - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + result1_down = lora_down1.weight.data.clone() + result1_up = lora_up1.weight.data.clone() - # Test with rank equal to min dimension - lora_down = torch.nn.Linear(20, 10) - lora_up = torch.nn.Linear(10, 10) - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10) + # Reset module + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + # Second run with same seed + lora_down2 = torch.nn.Linear(20, 3) + lora_up2 = torch.nn.Linear(3, 10) + initialize_pissa(org_module, lora_down2, lora_up2, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42) + + # Results should be identical + torch.testing.assert_close(result1_down, lora_down2.weight.data) + torch.testing.assert_close(result1_up, lora_up2.weight.data) + + +def test_initialize_pissa_ipca_with_lowrank(): + # Test IncrementalPCA with low-rank SVD enabled + org_module = torch.nn.Linear(200, 100) # Larger dimensions + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + lora_down = torch.nn.Linear(200, 10) + lora_up = torch.nn.Linear(10, 100) + + # Call with both IPCA and low-rank enabled + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10, use_ipca=True, use_lowrank=True, lowrank_q=20) + + # Check shapes of resulting matrices + assert lora_down.weight.shape == torch.Size([10, 200]) + assert lora_up.weight.shape == torch.Size([100, 10]) + + +def test_initialize_pissa_custom_lowrank_params(): + # Test with custom low-rank parameters + org_module = torch.nn.Linear(30, 20) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + lora_down = torch.nn.Linear(30, 5) + lora_up = torch.nn.Linear(5, 20) + + # Test with custom q value and iterations + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True, lowrank_q=12, lowrank_niter=6) + + # Check basic validity + assert lora_down.weight.data is not None + assert lora_up.weight.data is not None + + +def test_initialize_pissa_device_handling(): + # Test different device scenarios + devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")] + + for device in devices: + # Create modules on specific device + org_module = torch.nn.Linear(10, 5).to(device) + lora_down = torch.nn.Linear(10, 2).to(device) + lora_up = torch.nn.Linear(2, 5).to(device) + + # Test initialization with explicit device + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device) + + # Verify modules are on the correct device + assert org_module.weight.data.device.type == device.type + assert lora_down.weight.data.device.type == device.type + assert lora_up.weight.data.device.type == device.type + + # Test with IPCA + if device.type == "cpu": # IPCA might be slow on CPU for large matrices + org_module_small = torch.nn.Linear(20, 10).to(device) + lora_down_small = torch.nn.Linear(20, 3).to(device) + lora_up_small = torch.nn.Linear(3, 10).to(device) + + initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device, use_ipca=True) + + assert org_module_small.weight.data.device.type == device.type def test_initialize_pissa_shape_mismatch(): @@ -118,25 +259,6 @@ def test_initialize_pissa_svd_properties(): torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4) -def test_initialize_pissa_device_handling(): - # Test different device scenarios - devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")] - - for device in devices: - # Create modules on specific device - org_module = torch.nn.Linear(10, 5).to(device) - lora_down = torch.nn.Linear(10, 2).to(device) - lora_up = torch.nn.Linear(2, 5).to(device) - - # Test initialization with explicit device - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device) - - # Verify modules are on the correct device - assert org_module.weight.data.device.type == device.type - assert lora_down.weight.data.device.type == device.type - assert lora_up.weight.data.device.type == device.type - - def test_initialize_pissa_dtype_preservation(): # Test dtype preservation and conversion dtypes = [torch.float16, torch.float32, torch.float64] @@ -152,59 +274,62 @@ def test_initialize_pissa_dtype_preservation(): assert lora_down.weight.dtype == dtype assert lora_up.weight.dtype == dtype + # Test with explicit dtype + if dtype != torch.float16: # Skip float16 for computational stability in SVD + org_module2 = torch.nn.Linear(10, 5).to(dtype=torch.float32) + lora_down2 = torch.nn.Linear(10, 2).to(dtype=torch.float32) + lora_up2 = torch.nn.Linear(2, 5).to(dtype=torch.float32) -def test_initialize_pissa_rank_limits(): - # Test rank limits - org_module = torch.nn.Linear(10, 5) + initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype) - # Test minimum rank (should work) - lora_down_min = torch.nn.Linear(10, 1) - lora_up_min = torch.nn.Linear(1, 5) - initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1) + # Original module should be converted to specified dtype + assert org_module2.weight.dtype == dtype - # Test maximum rank (rank = min(input_dim, output_dim)) - max_rank = min(10, 5) - lora_down_max = torch.nn.Linear(10, max_rank) - lora_up_max = torch.nn.Linear(max_rank, 5) - initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank) def test_initialize_pissa_numerical_stability(): # Test with numerically challenging scenarios scenarios = [ - torch.randn(20, 10) * 1e-10, # Very small values - torch.randn(20, 10) * 1e10, # Very large values - torch.zeros(20, 10), # Zero matrix + torch.randn(20, 10) * 1e-5, # Small values + torch.randn(20, 10) * 1e5, # Large values + torch.ones(20, 10), # Uniform values ] for i, weight_matrix in enumerate(scenarios): org_module = torch.nn.Linear(20, 10) org_module.weight.data = weight_matrix - lora_down = torch.nn.Linear(10, 3) - lora_up = torch.nn.Linear(3, 20) + lora_down = torch.nn.Linear(20, 3) + lora_up = torch.nn.Linear(3, 10) try: initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + # Test IPCA as well + lora_down_ipca = torch.nn.Linear(20, 3) + lora_up_ipca = torch.nn.Linear(3, 10) + initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3, use_ipca=True) except Exception as e: pytest.fail(f"Initialization failed for scenario ({i}): {e}") def test_initialize_pissa_scale_effects(): - # Test different scaling factors - org_module = torch.nn.Linear(10, 5) - original_weight = org_module.weight.data.clone() + # Test effect of different scaling factors + org_module = torch.nn.Linear(15, 10) + original_weight = torch.randn_like(org_module.weight.data) + org_module.weight.data = original_weight.clone() - test_scales = [0.0, 0.1, 0.5, 1.0] + # Try different scales + scales = [0.0, 0.01, 0.1, 1.0] - for scale in test_scales: - # Reset module for each test + for scale in scales: + # Reset to original weights org_module.weight.data = original_weight.clone() - lora_down = torch.nn.Linear(10, 2) - lora_up = torch.nn.Linear(2, 5) + lora_down = torch.nn.Linear(15, 4) + lora_up = torch.nn.Linear(4, 10) - initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2) + initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=4) # Verify weight modification proportional to scale weight_diff = original_weight - org_module.weight.data @@ -213,4 +338,84 @@ def test_initialize_pissa_scale_effects(): if scale == 0.0: torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6) else: - assert not torch.allclose(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6) + # For non-zero scales, verify the magnitude of change is proportional to scale + assert weight_diff.abs().sum() > 0 + + # Do a second run with double the scale + org_module2 = torch.nn.Linear(15, 10) + org_module2.weight.data = original_weight.clone() + + lora_down2 = torch.nn.Linear(15, 4) + lora_up2 = torch.nn.Linear(4, 10) + + initialize_pissa(org_module2, lora_down2, lora_up2, scale=scale * 2, rank=4) + + weight_diff2 = original_weight - org_module2.weight.data + + # The ratio of differences should be approximately 2 + # (allowing for numerical precision issues) + ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10) + assert 1.9 < ratio < 2.1 + +def test_initialize_pissa_large_matrix_performance(): + # Test with a large matrix to ensure it works well + # This is particularly relevant for IPCA mode + + # Skip if running on CPU to avoid long test times + if not torch.cuda.is_available(): + pytest.skip("Skipping large matrix test on CPU") + + org_module = torch.nn.Linear(1000, 500) + org_module.weight.data = torch.randn_like(org_module.weight.data) * 0.1 + + lora_down = torch.nn.Linear(1000, 16) + lora_up = torch.nn.Linear(16, 500) + + # Test standard approach + try: + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=16) + except Exception as e: + pytest.fail(f"Standard SVD failed on large matrix: {e}") + + # Test IPCA approach + lora_down_ipca = torch.nn.Linear(1000, 16) + lora_up_ipca = torch.nn.Linear(16, 500) + + try: + initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16, use_ipca=True) + except Exception as e: + pytest.fail(f"IPCA approach failed on large matrix: {e}") + + # Test IPCA with lowrank + lora_down_both = torch.nn.Linear(1000, 16) + lora_up_both = torch.nn.Linear(16, 500) + + try: + initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_ipca=True, use_lowrank=True) + except Exception as e: + pytest.fail(f"Combined IPCA+lowrank approach failed on large matrix: {e}") + + +def test_initialize_pissa_requires_grad_preservation(): + # Test that requires_grad property is preserved + org_module = torch.nn.Linear(20, 10) + org_module.weight.requires_grad = False + + lora_down = torch.nn.Linear(20, 4) + lora_up = torch.nn.Linear(4, 10) + + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=4) + + # Check requires_grad is preserved + assert not org_module.weight.requires_grad + + # Test with requires_grad=True + org_module2 = torch.nn.Linear(20, 10) + org_module2.weight.requires_grad = True + + initialize_pissa(org_module2, lora_down, lora_up, scale=0.1, rank=4) + + # Check requires_grad is preserved + assert org_module2.weight.requires_grad + +