Fix tests for PiSSA, fix lowrank SVD, Remove ICPA

This commit is contained in:
rockerBOO
2025-06-03 17:00:31 -04:00
parent faab3f0440
commit 5e35ea5d7d
4 changed files with 83 additions and 171 deletions

View File

@@ -11,11 +11,9 @@ from dataclasses import dataclass
class InitializeParams:
"""Parameters for initialization methods (PiSSA, URAE)"""
use_ipca: bool = False
use_lowrank: bool = False
lowrank_q: Optional[int] = None
lowrank_niter: int = 4
lowrank_seed: Optional[int] = None
def initialize_parse_opts(key: str) -> InitializeParams:
@@ -26,7 +24,6 @@ def initialize_parse_opts(key: str) -> InitializeParams:
- "pissa" -> Default PiSSA with lowrank=True, niter=4
- "pissa_niter_4" -> PiSSA with niter=4
- "pissa_lowrank_false" -> PiSSA without lowrank
- "pissa_ipca_true" -> PiSSA with IPCA
- "pissa_q_16" -> PiSSA with lowrank_q=16
- "pissa_seed_42" -> PiSSA with seed=42
- "urae_..." -> Same options but for URAE
@@ -50,14 +47,7 @@ def initialize_parse_opts(key: str) -> InitializeParams:
# Parse the remaining parts
i = 1
while i < len(parts):
if parts[i] == "ipca":
if i + 1 < len(parts) and parts[i + 1] in ["true", "false"]:
params.use_ipca = parts[i + 1] == "true"
i += 2
else:
params.use_ipca = True
i += 1
elif parts[i] == "lowrank":
if parts[i] == "lowrank":
if i + 1 < len(parts) and parts[i + 1] in ["true", "false"]:
params.use_lowrank = parts[i + 1] == "true"
i += 2
@@ -76,12 +66,6 @@ def initialize_parse_opts(key: str) -> InitializeParams:
i += 2
else:
i += 1
elif parts[i] == "seed":
if i + 1 < len(parts) and parts[i + 1].isdigit():
params.lowrank_seed = int(parts[i + 1])
i += 2
else:
i += 1
else:
# Skip unknown parameter
i += 1
@@ -188,11 +172,9 @@ def initialize_pissa(
rank: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
use_ipca: bool = False,
use_lowrank: bool = False,
lowrank_q: Optional[int] = None,
lowrank_niter: int = 4,
lowrank_seed: Optional[int] = None,
):
org_module_device = org_module.weight.device
org_module_weight_dtype = org_module.weight.data.dtype
@@ -205,56 +187,21 @@ def initialize_pissa(
weight = org_module.weight.data.clone().to(device, dtype=torch.float32)
with torch.no_grad():
if use_ipca:
# Use Incremental PCA for large matrices
ipca = IncrementalPCA(
n_components=rank,
batch_size=1024,
lowrank=use_lowrank,
lowrank_q=lowrank_q if lowrank_q is not None else 2 * rank,
lowrank_niter=lowrank_niter,
lowrank_seed=lowrank_seed,
)
ipca.fit(weight)
# Extract principal components and singular values
Vr = ipca.components_.T # [out_features, rank]
Sr = ipca.singular_values_ # [rank]
Sr /= rank
# We need to get Uhr from transforming an identity matrix
identity = torch.eye(weight.shape[1], device=weight.device)
with torch.autocast(device.type, dtype=torch.float64):
Uhr = ipca.transform(identity).T # [rank, in_features]
elif use_lowrank:
# Use low-rank SVD approximation which is faster
seed_enabled = lowrank_seed is not None
if use_lowrank:
q_value = lowrank_q if lowrank_q is not None else 2 * rank
with torch.random.fork_rng(enabled=seed_enabled):
if seed_enabled:
torch.manual_seed(lowrank_seed)
U, S, V = torch.svd_lowrank(weight, q=q_value, niter=lowrank_niter)
Vr = U[:, :rank] # First rank left singular vectors
Sr = S[:rank] # First rank singular values
Vr, Sr, Ur = torch.svd_lowrank(weight.data, q=q_value, niter=lowrank_niter)
Sr /= rank
Uhr = V[:rank] # First rank right singular vectors
Uhr = Ur.t()
else:
# Standard SVD approach
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, :rank]
Sr = S[:rank]
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
Vr = V[:, : rank]
Sr = S[: rank]
Sr /= rank
Uhr = Uh[:rank]
Uhr = Uh[: rank]
# Uhr may be in higher precision
with torch.autocast(device.type, dtype=Uhr.dtype):
# Create down and up matrices
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
# Get expected shapes
expected_down_shape = lora_down.weight.shape

View File

@@ -1,34 +0,0 @@
import torch
def generate_synthetic_weights(org_weight, seed=42):
generator = torch.manual_seed(seed)
# Base random normal distribution
weights = torch.randn_like(org_weight)
# Add structured variance to mimic real-world weight matrices
# Techniques to create more realistic weight distributions:
# 1. Block-wise variation
block_size = max(1, org_weight.shape[0] // 4)
for i in range(0, org_weight.shape[0], block_size):
block_end = min(i + block_size, org_weight.shape[0])
block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling
weights[i:block_end, :] *= (1 + block_variation)
# 2. Sparse connectivity simulation
sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity
weights *= sparsity_mask.float()
# 3. Magnitude decay
magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1)
weights *= magnitude_decay
# 4. Add structured noise
structural_noise = torch.randn_like(org_weight) * 0.1
weights += structural_noise
# Normalize to have similar statistical properties to trained weights
weights = (weights - weights.mean()) / weights.std()
return weights

View File

@@ -1,7 +1,40 @@
import torch
import pytest
from library.network_utils import initialize_pissa
from library.test_util import generate_synthetic_weights
def generate_synthetic_weights(org_weight, seed=42):
generator = torch.manual_seed(seed)
# Base random normal distribution
weights = torch.randn_like(org_weight)
# Add structured variance to mimic real-world weight matrices
# Techniques to create more realistic weight distributions:
# 1. Block-wise variation
block_size = max(1, org_weight.shape[0] // 4)
for i in range(0, org_weight.shape[0], block_size):
block_end = min(i + block_size, org_weight.shape[0])
block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling
weights[i:block_end, :] *= 1 + block_variation
# 2. Sparse connectivity simulation
sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity
weights *= sparsity_mask.float()
# 3. Magnitude decay
magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1)
weights *= magnitude_decay
# 4. Add structured noise
structural_noise = torch.randn_like(org_weight) * 0.1
weights += structural_noise
# Normalize to have similar statistical properties to trained weights
weights = (weights - weights.mean()) / weights.std()
return weights
def test_initialize_pissa_rank_constraints():
@@ -66,27 +99,6 @@ def test_initialize_pissa_basic():
assert not torch.equal(original_weight, org_module.weight.data)
def test_initialize_pissa_with_ipca():
# Test with IncrementalPCA option
org_module = torch.nn.Linear(100, 50) # Larger dimensions to test IPCA
org_module.weight.data = generate_synthetic_weights(org_module.weight)
lora_down = torch.nn.Linear(100, 8)
lora_up = torch.nn.Linear(8, 50)
original_weight = org_module.weight.data.clone()
# Call with IPCA enabled
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=8, use_ipca=True)
# Verify weights are changed
assert not torch.equal(original_weight, org_module.weight.data)
# Check that LoRA matrices have appropriate shapes
assert lora_down.weight.shape == torch.Size([8, 100])
assert lora_up.weight.shape == torch.Size([50, 8])
def test_initialize_pissa_with_lowrank():
# Test with low-rank SVD option
org_module = torch.nn.Linear(50, 30)
@@ -104,48 +116,6 @@ def test_initialize_pissa_with_lowrank():
assert not torch.equal(original_weight, org_module.weight.data)
def test_initialize_pissa_with_lowrank_seed():
# Test reproducibility with seed
org_module = torch.nn.Linear(20, 10)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
# First run with seed
lora_down1 = torch.nn.Linear(20, 3)
lora_up1 = torch.nn.Linear(3, 10)
initialize_pissa(org_module, lora_down1, lora_up1, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42)
result1_down = lora_down1.weight.data.clone()
result1_up = lora_up1.weight.data.clone()
# Reset module
org_module.weight.data = generate_synthetic_weights(org_module.weight)
# Second run with same seed
lora_down2 = torch.nn.Linear(20, 3)
lora_up2 = torch.nn.Linear(3, 10)
initialize_pissa(org_module, lora_down2, lora_up2, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42)
# Results should be identical
torch.testing.assert_close(result1_down, lora_down2.weight.data)
torch.testing.assert_close(result1_up, lora_up2.weight.data)
def test_initialize_pissa_ipca_with_lowrank():
# Test IncrementalPCA with low-rank SVD enabled
org_module = torch.nn.Linear(200, 100) # Larger dimensions
org_module.weight.data = generate_synthetic_weights(org_module.weight)
lora_down = torch.nn.Linear(200, 10)
lora_up = torch.nn.Linear(10, 100)
# Call with both IPCA and low-rank enabled
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10, use_ipca=True, use_lowrank=True, lowrank_q=20)
# Check shapes of resulting matrices
assert lora_down.weight.shape == torch.Size([10, 200])
assert lora_up.weight.shape == torch.Size([100, 10])
def test_initialize_pissa_custom_lowrank_params():
# Test with custom low-rank parameters
org_module = torch.nn.Linear(30, 20)
@@ -186,7 +156,7 @@ def test_initialize_pissa_device_handling():
lora_down_small = torch.nn.Linear(20, 3).to(device)
lora_up_small = torch.nn.Linear(3, 10).to(device)
initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device, use_ipca=True)
initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device)
assert org_module_small.weight.data.device.type == device.type
@@ -283,8 +253,7 @@ def test_initialize_pissa_dtype_preservation():
initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype)
# Original module should be converted to specified dtype
assert org_module2.weight.dtype == dtype
assert org_module2.weight.dtype == torch.float32
def test_initialize_pissa_numerical_stability():
@@ -308,7 +277,7 @@ def test_initialize_pissa_numerical_stability():
# Test IPCA as well
lora_down_ipca = torch.nn.Linear(20, 3)
lora_up_ipca = torch.nn.Linear(3, 10)
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3, use_ipca=True)
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3)
except Exception as e:
pytest.fail(f"Initialization failed for scenario ({i}): {e}")
@@ -357,6 +326,7 @@ def test_initialize_pissa_scale_effects():
ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10)
assert 1.9 < ratio < 2.1
def test_initialize_pissa_large_matrix_performance():
# Test with a large matrix to ensure it works well
# This is particularly relevant for IPCA mode
@@ -382,7 +352,7 @@ def test_initialize_pissa_large_matrix_performance():
lora_up_ipca = torch.nn.Linear(16, 500)
try:
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16, use_ipca=True)
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16)
except Exception as e:
pytest.fail(f"IPCA approach failed on large matrix: {e}")
@@ -391,7 +361,7 @@ def test_initialize_pissa_large_matrix_performance():
lora_up_both = torch.nn.Linear(16, 500)
try:
initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_ipca=True, use_lowrank=True)
initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_lowrank=True)
except Exception as e:
pytest.fail(f"Combined IPCA+lowrank approach failed on large matrix: {e}")
@@ -417,5 +387,3 @@ def test_initialize_pissa_requires_grad_preservation():
# Check requires_grad is preserved
assert org_module2.weight.requires_grad

View File

@@ -2,9 +2,40 @@ import pytest
import torch
import torch.nn as nn
from networks.lora_flux import LoRAModule, LoRANetwork, create_network
from library.test_util import generate_synthetic_weights
from unittest.mock import MagicMock
def generate_synthetic_weights(org_weight, seed=42):
generator = torch.manual_seed(seed)
# Base random normal distribution
weights = torch.randn_like(org_weight)
# Add structured variance to mimic real-world weight matrices
# Techniques to create more realistic weight distributions:
# 1. Block-wise variation
block_size = max(1, org_weight.shape[0] // 4)
for i in range(0, org_weight.shape[0], block_size):
block_end = min(i + block_size, org_weight.shape[0])
block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling
weights[i:block_end, :] *= 1 + block_variation
# 2. Sparse connectivity simulation
sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity
weights *= sparsity_mask.float()
# 3. Magnitude decay
magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1)
weights *= magnitude_decay
# 4. Add structured noise
structural_noise = torch.randn_like(org_weight) * 0.1
weights += structural_noise
# Normalize to have similar statistical properties to trained weights
weights = (weights - weights.mean()) / weights.std()
return weights
def test_basic_linear_module_initialization():
# Test basic Linear module initialization