diff --git a/library/network_utils.py b/library/network_utils.py index 1b51cb44..f58cee4e 100644 --- a/library/network_utils.py +++ b/library/network_utils.py @@ -11,11 +11,9 @@ from dataclasses import dataclass class InitializeParams: """Parameters for initialization methods (PiSSA, URAE)""" - use_ipca: bool = False use_lowrank: bool = False lowrank_q: Optional[int] = None lowrank_niter: int = 4 - lowrank_seed: Optional[int] = None def initialize_parse_opts(key: str) -> InitializeParams: @@ -26,7 +24,6 @@ def initialize_parse_opts(key: str) -> InitializeParams: - "pissa" -> Default PiSSA with lowrank=True, niter=4 - "pissa_niter_4" -> PiSSA with niter=4 - "pissa_lowrank_false" -> PiSSA without lowrank - - "pissa_ipca_true" -> PiSSA with IPCA - "pissa_q_16" -> PiSSA with lowrank_q=16 - "pissa_seed_42" -> PiSSA with seed=42 - "urae_..." -> Same options but for URAE @@ -50,14 +47,7 @@ def initialize_parse_opts(key: str) -> InitializeParams: # Parse the remaining parts i = 1 while i < len(parts): - if parts[i] == "ipca": - if i + 1 < len(parts) and parts[i + 1] in ["true", "false"]: - params.use_ipca = parts[i + 1] == "true" - i += 2 - else: - params.use_ipca = True - i += 1 - elif parts[i] == "lowrank": + if parts[i] == "lowrank": if i + 1 < len(parts) and parts[i + 1] in ["true", "false"]: params.use_lowrank = parts[i + 1] == "true" i += 2 @@ -76,12 +66,6 @@ def initialize_parse_opts(key: str) -> InitializeParams: i += 2 else: i += 1 - elif parts[i] == "seed": - if i + 1 < len(parts) and parts[i + 1].isdigit(): - params.lowrank_seed = int(parts[i + 1]) - i += 2 - else: - i += 1 else: # Skip unknown parameter i += 1 @@ -188,11 +172,9 @@ def initialize_pissa( rank: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - use_ipca: bool = False, use_lowrank: bool = False, lowrank_q: Optional[int] = None, lowrank_niter: int = 4, - lowrank_seed: Optional[int] = None, ): org_module_device = org_module.weight.device org_module_weight_dtype = org_module.weight.data.dtype @@ -205,56 +187,21 @@ def initialize_pissa( weight = org_module.weight.data.clone().to(device, dtype=torch.float32) with torch.no_grad(): - if use_ipca: - # Use Incremental PCA for large matrices - ipca = IncrementalPCA( - n_components=rank, - batch_size=1024, - lowrank=use_lowrank, - lowrank_q=lowrank_q if lowrank_q is not None else 2 * rank, - lowrank_niter=lowrank_niter, - lowrank_seed=lowrank_seed, - ) - ipca.fit(weight) - - # Extract principal components and singular values - Vr = ipca.components_.T # [out_features, rank] - Sr = ipca.singular_values_ # [rank] - Sr /= rank - - # We need to get Uhr from transforming an identity matrix - identity = torch.eye(weight.shape[1], device=weight.device) - with torch.autocast(device.type, dtype=torch.float64): - Uhr = ipca.transform(identity).T # [rank, in_features] - - elif use_lowrank: - # Use low-rank SVD approximation which is faster - seed_enabled = lowrank_seed is not None + if use_lowrank: q_value = lowrank_q if lowrank_q is not None else 2 * rank - - with torch.random.fork_rng(enabled=seed_enabled): - if seed_enabled: - torch.manual_seed(lowrank_seed) - U, S, V = torch.svd_lowrank(weight, q=q_value, niter=lowrank_niter) - - Vr = U[:, :rank] # First rank left singular vectors - Sr = S[:rank] # First rank singular values + Vr, Sr, Ur = torch.svd_lowrank(weight.data, q=q_value, niter=lowrank_niter) Sr /= rank - Uhr = V[:rank] # First rank right singular vectors - + Uhr = Ur.t() else: - # Standard SVD approach - V, S, Uh = torch.linalg.svd(weight, full_matrices=False) - Vr = V[:, :rank] - Sr = S[:rank] + # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, + V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False) + Vr = V[:, : rank] + Sr = S[: rank] Sr /= rank - Uhr = Uh[:rank] + Uhr = Uh[: rank] - # Uhr may be in higher precision - with torch.autocast(device.type, dtype=Uhr.dtype): - # Create down and up matrices - down = torch.diag(torch.sqrt(Sr)) @ Uhr - up = Vr @ torch.diag(torch.sqrt(Sr)) + down = torch.diag(torch.sqrt(Sr)) @ Uhr + up = Vr @ torch.diag(torch.sqrt(Sr)) # Get expected shapes expected_down_shape = lora_down.weight.shape diff --git a/library/test_util.py b/library/test_util.py deleted file mode 100644 index 3f866455..00000000 --- a/library/test_util.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch - -def generate_synthetic_weights(org_weight, seed=42): - generator = torch.manual_seed(seed) - - # Base random normal distribution - weights = torch.randn_like(org_weight) - - # Add structured variance to mimic real-world weight matrices - # Techniques to create more realistic weight distributions: - - # 1. Block-wise variation - block_size = max(1, org_weight.shape[0] // 4) - for i in range(0, org_weight.shape[0], block_size): - block_end = min(i + block_size, org_weight.shape[0]) - block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling - weights[i:block_end, :] *= (1 + block_variation) - - # 2. Sparse connectivity simulation - sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity - weights *= sparsity_mask.float() - - # 3. Magnitude decay - magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1) - weights *= magnitude_decay - - # 4. Add structured noise - structural_noise = torch.randn_like(org_weight) * 0.1 - weights += structural_noise - - # Normalize to have similar statistical properties to trained weights - weights = (weights - weights.mean()) / weights.std() - - return weights diff --git a/tests/library/test_network_utils.py b/tests/library/test_network_utils.py index 634c9937..363f2b85 100644 --- a/tests/library/test_network_utils.py +++ b/tests/library/test_network_utils.py @@ -1,7 +1,40 @@ import torch import pytest from library.network_utils import initialize_pissa -from library.test_util import generate_synthetic_weights + + +def generate_synthetic_weights(org_weight, seed=42): + generator = torch.manual_seed(seed) + + # Base random normal distribution + weights = torch.randn_like(org_weight) + + # Add structured variance to mimic real-world weight matrices + # Techniques to create more realistic weight distributions: + + # 1. Block-wise variation + block_size = max(1, org_weight.shape[0] // 4) + for i in range(0, org_weight.shape[0], block_size): + block_end = min(i + block_size, org_weight.shape[0]) + block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling + weights[i:block_end, :] *= 1 + block_variation + + # 2. Sparse connectivity simulation + sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity + weights *= sparsity_mask.float() + + # 3. Magnitude decay + magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1) + weights *= magnitude_decay + + # 4. Add structured noise + structural_noise = torch.randn_like(org_weight) * 0.1 + weights += structural_noise + + # Normalize to have similar statistical properties to trained weights + weights = (weights - weights.mean()) / weights.std() + + return weights def test_initialize_pissa_rank_constraints(): @@ -66,27 +99,6 @@ def test_initialize_pissa_basic(): assert not torch.equal(original_weight, org_module.weight.data) -def test_initialize_pissa_with_ipca(): - # Test with IncrementalPCA option - org_module = torch.nn.Linear(100, 50) # Larger dimensions to test IPCA - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - lora_down = torch.nn.Linear(100, 8) - lora_up = torch.nn.Linear(8, 50) - - original_weight = org_module.weight.data.clone() - - # Call with IPCA enabled - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=8, use_ipca=True) - - # Verify weights are changed - assert not torch.equal(original_weight, org_module.weight.data) - - # Check that LoRA matrices have appropriate shapes - assert lora_down.weight.shape == torch.Size([8, 100]) - assert lora_up.weight.shape == torch.Size([50, 8]) - - def test_initialize_pissa_with_lowrank(): # Test with low-rank SVD option org_module = torch.nn.Linear(50, 30) @@ -104,48 +116,6 @@ def test_initialize_pissa_with_lowrank(): assert not torch.equal(original_weight, org_module.weight.data) -def test_initialize_pissa_with_lowrank_seed(): - # Test reproducibility with seed - org_module = torch.nn.Linear(20, 10) - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - # First run with seed - lora_down1 = torch.nn.Linear(20, 3) - lora_up1 = torch.nn.Linear(3, 10) - initialize_pissa(org_module, lora_down1, lora_up1, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42) - - result1_down = lora_down1.weight.data.clone() - result1_up = lora_up1.weight.data.clone() - - # Reset module - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - # Second run with same seed - lora_down2 = torch.nn.Linear(20, 3) - lora_up2 = torch.nn.Linear(3, 10) - initialize_pissa(org_module, lora_down2, lora_up2, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42) - - # Results should be identical - torch.testing.assert_close(result1_down, lora_down2.weight.data) - torch.testing.assert_close(result1_up, lora_up2.weight.data) - - -def test_initialize_pissa_ipca_with_lowrank(): - # Test IncrementalPCA with low-rank SVD enabled - org_module = torch.nn.Linear(200, 100) # Larger dimensions - org_module.weight.data = generate_synthetic_weights(org_module.weight) - - lora_down = torch.nn.Linear(200, 10) - lora_up = torch.nn.Linear(10, 100) - - # Call with both IPCA and low-rank enabled - initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10, use_ipca=True, use_lowrank=True, lowrank_q=20) - - # Check shapes of resulting matrices - assert lora_down.weight.shape == torch.Size([10, 200]) - assert lora_up.weight.shape == torch.Size([100, 10]) - - def test_initialize_pissa_custom_lowrank_params(): # Test with custom low-rank parameters org_module = torch.nn.Linear(30, 20) @@ -186,7 +156,7 @@ def test_initialize_pissa_device_handling(): lora_down_small = torch.nn.Linear(20, 3).to(device) lora_up_small = torch.nn.Linear(3, 10).to(device) - initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device, use_ipca=True) + initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device) assert org_module_small.weight.data.device.type == device.type @@ -283,8 +253,7 @@ def test_initialize_pissa_dtype_preservation(): initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype) # Original module should be converted to specified dtype - assert org_module2.weight.dtype == dtype - + assert org_module2.weight.dtype == torch.float32 def test_initialize_pissa_numerical_stability(): @@ -308,7 +277,7 @@ def test_initialize_pissa_numerical_stability(): # Test IPCA as well lora_down_ipca = torch.nn.Linear(20, 3) lora_up_ipca = torch.nn.Linear(3, 10) - initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3, use_ipca=True) + initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3) except Exception as e: pytest.fail(f"Initialization failed for scenario ({i}): {e}") @@ -357,6 +326,7 @@ def test_initialize_pissa_scale_effects(): ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10) assert 1.9 < ratio < 2.1 + def test_initialize_pissa_large_matrix_performance(): # Test with a large matrix to ensure it works well # This is particularly relevant for IPCA mode @@ -382,7 +352,7 @@ def test_initialize_pissa_large_matrix_performance(): lora_up_ipca = torch.nn.Linear(16, 500) try: - initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16, use_ipca=True) + initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16) except Exception as e: pytest.fail(f"IPCA approach failed on large matrix: {e}") @@ -391,7 +361,7 @@ def test_initialize_pissa_large_matrix_performance(): lora_up_both = torch.nn.Linear(16, 500) try: - initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_ipca=True, use_lowrank=True) + initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_lowrank=True) except Exception as e: pytest.fail(f"Combined IPCA+lowrank approach failed on large matrix: {e}") @@ -417,5 +387,3 @@ def test_initialize_pissa_requires_grad_preservation(): # Check requires_grad is preserved assert org_module2.weight.requires_grad - - diff --git a/tests/networks/test_lora_flux.py b/tests/networks/test_lora_flux.py index d5a516ce..e526d4e4 100644 --- a/tests/networks/test_lora_flux.py +++ b/tests/networks/test_lora_flux.py @@ -2,9 +2,40 @@ import pytest import torch import torch.nn as nn from networks.lora_flux import LoRAModule, LoRANetwork, create_network -from library.test_util import generate_synthetic_weights from unittest.mock import MagicMock +def generate_synthetic_weights(org_weight, seed=42): + generator = torch.manual_seed(seed) + + # Base random normal distribution + weights = torch.randn_like(org_weight) + + # Add structured variance to mimic real-world weight matrices + # Techniques to create more realistic weight distributions: + + # 1. Block-wise variation + block_size = max(1, org_weight.shape[0] // 4) + for i in range(0, org_weight.shape[0], block_size): + block_end = min(i + block_size, org_weight.shape[0]) + block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling + weights[i:block_end, :] *= 1 + block_variation + + # 2. Sparse connectivity simulation + sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity + weights *= sparsity_mask.float() + + # 3. Magnitude decay + magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1) + weights *= magnitude_decay + + # 4. Add structured noise + structural_noise = torch.randn_like(org_weight) * 0.1 + weights += structural_noise + + # Normalize to have similar statistical properties to trained weights + weights = (weights - weights.mean()) / weights.std() + + return weights def test_basic_linear_module_initialization(): # Test basic Linear module initialization