mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Fix lowrank PISSA. Add more tests
This commit is contained in:
@@ -471,7 +471,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
args, noise_scheduler, latents, noise, device, dtype, num_timesteps=1000
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
@@ -12,7 +12,7 @@ class InitializeParams:
|
||||
"""Parameters for initialization methods (PiSSA, URAE)"""
|
||||
|
||||
use_lowrank: bool = False
|
||||
lowrank_q: Optional[int] = None
|
||||
# lowrank_q: Optional[int] = None
|
||||
lowrank_niter: int = 4
|
||||
|
||||
|
||||
@@ -24,8 +24,6 @@ def initialize_parse_opts(key: str) -> InitializeParams:
|
||||
- "pissa" -> Default PiSSA with lowrank=True, niter=4
|
||||
- "pissa_niter_4" -> PiSSA with niter=4
|
||||
- "pissa_lowrank_false" -> PiSSA without lowrank
|
||||
- "pissa_q_16" -> PiSSA with lowrank_q=16
|
||||
- "pissa_seed_42" -> PiSSA with seed=42
|
||||
- "urae_..." -> Same options but for URAE
|
||||
|
||||
Args:
|
||||
@@ -57,12 +55,7 @@ def initialize_parse_opts(key: str) -> InitializeParams:
|
||||
elif parts[i] == "niter":
|
||||
if i + 1 < len(parts) and parts[i + 1].isdigit():
|
||||
params.lowrank_niter = int(parts[i + 1])
|
||||
i += 2
|
||||
else:
|
||||
i += 1
|
||||
elif parts[i] == "q":
|
||||
if i + 1 < len(parts) and parts[i + 1].isdigit():
|
||||
params.lowrank_q = int(parts[i + 1])
|
||||
params.use_lowrank = True
|
||||
i += 2
|
||||
else:
|
||||
i += 1
|
||||
@@ -173,7 +166,7 @@ def initialize_pissa(
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
use_lowrank: bool = False,
|
||||
lowrank_q: Optional[int] = None,
|
||||
# lowrank_q: Optional[int] = None,
|
||||
lowrank_niter: int = 4,
|
||||
):
|
||||
org_module_device = org_module.weight.device
|
||||
@@ -188,17 +181,17 @@ def initialize_pissa(
|
||||
|
||||
with torch.no_grad():
|
||||
if use_lowrank:
|
||||
q_value = lowrank_q if lowrank_q is not None else 2 * rank
|
||||
Vr, Sr, Ur = torch.svd_lowrank(weight.data, q=q_value, niter=lowrank_niter)
|
||||
# q_value = lowrank_q if lowrank_q is not None else 2 * rank
|
||||
Vr, Sr, Ur = torch.svd_lowrank(weight.data, q=rank, niter=lowrank_niter)
|
||||
Sr /= rank
|
||||
Uhr = Ur.t()
|
||||
else:
|
||||
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
|
||||
V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
|
||||
Vr = V[:, : rank]
|
||||
Sr = S[: rank]
|
||||
Vr = V[:, :rank]
|
||||
Sr = S[:rank]
|
||||
Sr /= rank
|
||||
Uhr = Uh[: rank]
|
||||
Uhr = Uh[:rank]
|
||||
|
||||
down = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
up = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
@@ -222,13 +215,15 @@ def initialize_pissa(
|
||||
org_module.weight.requires_grad = org_module_requires_grad
|
||||
|
||||
|
||||
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
|
||||
def convert_pissa_to_standard_lora(
|
||||
trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int
|
||||
):
|
||||
with torch.no_grad():
|
||||
# Calculate ΔW = A'B' - AB
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# We need to create new low-rank matrices that represent this delta
|
||||
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
|
||||
U, S, V = torch.linalg.svd(delta_w.to(trained_up.device, dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# Take the top 2*r singular values (as suggested in the paper)
|
||||
rank = rank * 2
|
||||
|
||||
@@ -2,388 +2,3 @@ import torch
|
||||
import pytest
|
||||
from library.network_utils import initialize_pissa
|
||||
|
||||
|
||||
def generate_synthetic_weights(org_weight, seed=42):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Base random normal distribution
|
||||
weights = torch.randn_like(org_weight)
|
||||
|
||||
# Add structured variance to mimic real-world weight matrices
|
||||
# Techniques to create more realistic weight distributions:
|
||||
|
||||
# 1. Block-wise variation
|
||||
block_size = max(1, org_weight.shape[0] // 4)
|
||||
for i in range(0, org_weight.shape[0], block_size):
|
||||
block_end = min(i + block_size, org_weight.shape[0])
|
||||
block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling
|
||||
weights[i:block_end, :] *= 1 + block_variation
|
||||
|
||||
# 2. Sparse connectivity simulation
|
||||
sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity
|
||||
weights *= sparsity_mask.float()
|
||||
|
||||
# 3. Magnitude decay
|
||||
magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1)
|
||||
weights *= magnitude_decay
|
||||
|
||||
# 4. Add structured noise
|
||||
structural_noise = torch.randn_like(org_weight) * 0.1
|
||||
weights += structural_noise
|
||||
|
||||
# Normalize to have similar statistical properties to trained weights
|
||||
weights = (weights - weights.mean()) / weights.std()
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_constraints():
|
||||
# Test with different rank values
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
torch.nn.init.xavier_uniform_(org_module.weight)
|
||||
torch.nn.init.zeros_(org_module.bias)
|
||||
|
||||
# Test with rank less than min dimension
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Test with rank equal to min dimension
|
||||
lora_down = torch.nn.Linear(20, 10)
|
||||
lora_up = torch.nn.Linear(10, 10)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10)
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_limits():
|
||||
# Test rank limits
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
|
||||
# Test minimum rank (should work)
|
||||
lora_down_min = torch.nn.Linear(10, 1)
|
||||
lora_up_min = torch.nn.Linear(1, 5)
|
||||
initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1)
|
||||
|
||||
# Test maximum rank (rank = min(input_dim, output_dim))
|
||||
max_rank = min(10, 5)
|
||||
lora_down_max = torch.nn.Linear(10, max_rank)
|
||||
lora_up_max = torch.nn.Linear(max_rank, 5)
|
||||
initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank)
|
||||
|
||||
|
||||
def test_initialize_pissa_basic():
|
||||
# Create a simple linear layer
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
torch.nn.init.xavier_uniform_(org_module.weight)
|
||||
torch.nn.init.zeros_(org_module.bias)
|
||||
|
||||
# Create LoRA layers with matching shapes
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
# Store original weight for comparison
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
# Call the initialization function
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
# Verify basic properties
|
||||
assert lora_down.weight.data is not None
|
||||
assert lora_up.weight.data is not None
|
||||
assert org_module.weight.data is not None
|
||||
|
||||
# Check that the weights have been modified
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_with_lowrank():
|
||||
# Test with low-rank SVD option
|
||||
org_module = torch.nn.Linear(50, 30)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(50, 5)
|
||||
lora_up = torch.nn.Linear(5, 30)
|
||||
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
# Call with low-rank SVD enabled
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True)
|
||||
|
||||
# Verify weights are changed
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_custom_lowrank_params():
|
||||
# Test with custom low-rank parameters
|
||||
org_module = torch.nn.Linear(30, 20)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(30, 5)
|
||||
lora_up = torch.nn.Linear(5, 20)
|
||||
|
||||
# Test with custom q value and iterations
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True, lowrank_q=12, lowrank_niter=6)
|
||||
|
||||
# Check basic validity
|
||||
assert lora_down.weight.data is not None
|
||||
assert lora_up.weight.data is not None
|
||||
|
||||
|
||||
def test_initialize_pissa_device_handling():
|
||||
# Test different device scenarios
|
||||
devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")]
|
||||
|
||||
for device in devices:
|
||||
# Create modules on specific device
|
||||
org_module = torch.nn.Linear(10, 5).to(device)
|
||||
lora_down = torch.nn.Linear(10, 2).to(device)
|
||||
lora_up = torch.nn.Linear(2, 5).to(device)
|
||||
|
||||
# Test initialization with explicit device
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device)
|
||||
|
||||
# Verify modules are on the correct device
|
||||
assert org_module.weight.data.device.type == device.type
|
||||
assert lora_down.weight.data.device.type == device.type
|
||||
assert lora_up.weight.data.device.type == device.type
|
||||
|
||||
# Test with IPCA
|
||||
if device.type == "cpu": # IPCA might be slow on CPU for large matrices
|
||||
org_module_small = torch.nn.Linear(20, 10).to(device)
|
||||
lora_down_small = torch.nn.Linear(20, 3).to(device)
|
||||
lora_up_small = torch.nn.Linear(3, 10).to(device)
|
||||
|
||||
initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device)
|
||||
|
||||
assert org_module_small.weight.data.device.type == device.type
|
||||
|
||||
|
||||
def test_initialize_pissa_shape_mismatch():
|
||||
# Test with shape mismatch to ensure warning is printed
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
|
||||
# Intentionally mismatched shapes to test warning mechanism
|
||||
lora_down = torch.nn.Linear(20, 5) # Different shape
|
||||
lora_up = torch.nn.Linear(3, 15) # Different shape
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
|
||||
def test_initialize_pissa_scaling():
|
||||
# Test different scaling factors
|
||||
scales = [0.0, 0.1, 1.0]
|
||||
|
||||
for scale in scales:
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2)
|
||||
|
||||
# Check that the weight modification follows the scaling
|
||||
weight_diff = original_weight - org_module.weight.data
|
||||
expected_diff = scale * (lora_up.weight.data @ lora_down.weight.data)
|
||||
|
||||
torch.testing.assert_close(weight_diff, expected_diff, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_initialize_pissa_dtype():
|
||||
# Test with different data types
|
||||
dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
for dtype in dtypes:
|
||||
org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
# Verify output dtype matches input
|
||||
assert org_module.weight.dtype == dtype
|
||||
|
||||
|
||||
def test_initialize_pissa_svd_properties():
|
||||
# Verify SVD decomposition properties
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Reconstruct the weight
|
||||
reconstructed_weight = original_weight - 0.1 * (lora_up.weight.data @ lora_down.weight.data)
|
||||
|
||||
# Check reconstruction is close to original
|
||||
torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_initialize_pissa_dtype_preservation():
|
||||
# Test dtype preservation and conversion
|
||||
dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
for dtype in dtypes:
|
||||
org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
|
||||
lora_down = torch.nn.Linear(10, 2).to(dtype=dtype)
|
||||
lora_up = torch.nn.Linear(2, 5).to(dtype=dtype)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
assert org_module.weight.dtype == dtype
|
||||
assert lora_down.weight.dtype == dtype
|
||||
assert lora_up.weight.dtype == dtype
|
||||
|
||||
# Test with explicit dtype
|
||||
if dtype != torch.float16: # Skip float16 for computational stability in SVD
|
||||
org_module2 = torch.nn.Linear(10, 5).to(dtype=torch.float32)
|
||||
lora_down2 = torch.nn.Linear(10, 2).to(dtype=torch.float32)
|
||||
lora_up2 = torch.nn.Linear(2, 5).to(dtype=torch.float32)
|
||||
|
||||
initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype)
|
||||
|
||||
# Original module should be converted to specified dtype
|
||||
assert org_module2.weight.dtype == torch.float32
|
||||
|
||||
|
||||
def test_initialize_pissa_numerical_stability():
|
||||
# Test with numerically challenging scenarios
|
||||
scenarios = [
|
||||
torch.randn(20, 10) * 1e-5, # Small values
|
||||
torch.randn(20, 10) * 1e5, # Large values
|
||||
torch.ones(20, 10), # Uniform values
|
||||
]
|
||||
|
||||
for i, weight_matrix in enumerate(scenarios):
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = weight_matrix
|
||||
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Test IPCA as well
|
||||
lora_down_ipca = torch.nn.Linear(20, 3)
|
||||
lora_up_ipca = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Initialization failed for scenario ({i}): {e}")
|
||||
|
||||
|
||||
def test_initialize_pissa_scale_effects():
|
||||
# Test effect of different scaling factors
|
||||
org_module = torch.nn.Linear(15, 10)
|
||||
original_weight = torch.randn_like(org_module.weight.data)
|
||||
org_module.weight.data = original_weight.clone()
|
||||
|
||||
# Try different scales
|
||||
scales = [0.0, 0.01, 0.1, 1.0]
|
||||
|
||||
for scale in scales:
|
||||
# Reset to original weights
|
||||
org_module.weight.data = original_weight.clone()
|
||||
|
||||
lora_down = torch.nn.Linear(15, 4)
|
||||
lora_up = torch.nn.Linear(4, 10)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=4)
|
||||
|
||||
# Verify weight modification proportional to scale
|
||||
weight_diff = original_weight - org_module.weight.data
|
||||
|
||||
# Approximate check of scaling effect
|
||||
if scale == 0.0:
|
||||
torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6)
|
||||
else:
|
||||
# For non-zero scales, verify the magnitude of change is proportional to scale
|
||||
assert weight_diff.abs().sum() > 0
|
||||
|
||||
# Do a second run with double the scale
|
||||
org_module2 = torch.nn.Linear(15, 10)
|
||||
org_module2.weight.data = original_weight.clone()
|
||||
|
||||
lora_down2 = torch.nn.Linear(15, 4)
|
||||
lora_up2 = torch.nn.Linear(4, 10)
|
||||
|
||||
initialize_pissa(org_module2, lora_down2, lora_up2, scale=scale * 2, rank=4)
|
||||
|
||||
weight_diff2 = original_weight - org_module2.weight.data
|
||||
|
||||
# The ratio of differences should be approximately 2
|
||||
# (allowing for numerical precision issues)
|
||||
ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10)
|
||||
assert 1.9 < ratio < 2.1
|
||||
|
||||
|
||||
def test_initialize_pissa_large_matrix_performance():
|
||||
# Test with a large matrix to ensure it works well
|
||||
# This is particularly relevant for IPCA mode
|
||||
|
||||
# Skip if running on CPU to avoid long test times
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("Skipping large matrix test on CPU")
|
||||
|
||||
org_module = torch.nn.Linear(1000, 500)
|
||||
org_module.weight.data = torch.randn_like(org_module.weight.data) * 0.1
|
||||
|
||||
lora_down = torch.nn.Linear(1000, 16)
|
||||
lora_up = torch.nn.Linear(16, 500)
|
||||
|
||||
# Test standard approach
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=16)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Standard SVD failed on large matrix: {e}")
|
||||
|
||||
# Test IPCA approach
|
||||
lora_down_ipca = torch.nn.Linear(1000, 16)
|
||||
lora_up_ipca = torch.nn.Linear(16, 500)
|
||||
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16)
|
||||
except Exception as e:
|
||||
pytest.fail(f"IPCA approach failed on large matrix: {e}")
|
||||
|
||||
# Test IPCA with lowrank
|
||||
lora_down_both = torch.nn.Linear(1000, 16)
|
||||
lora_up_both = torch.nn.Linear(16, 500)
|
||||
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_lowrank=True)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Combined IPCA+lowrank approach failed on large matrix: {e}")
|
||||
|
||||
|
||||
def test_initialize_pissa_requires_grad_preservation():
|
||||
# Test that requires_grad property is preserved
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.requires_grad = False
|
||||
|
||||
lora_down = torch.nn.Linear(20, 4)
|
||||
lora_up = torch.nn.Linear(4, 10)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=4)
|
||||
|
||||
# Check requires_grad is preserved
|
||||
assert not org_module.weight.requires_grad
|
||||
|
||||
# Test with requires_grad=True
|
||||
org_module2 = torch.nn.Linear(20, 10)
|
||||
org_module2.weight.requires_grad = True
|
||||
|
||||
initialize_pissa(org_module2, lora_down, lora_up, scale=0.1, rank=4)
|
||||
|
||||
# Check requires_grad is preserved
|
||||
assert org_module2.weight.requires_grad
|
||||
|
||||
569
tests/library/test_network_utils_pissa.py
Normal file
569
tests/library/test_network_utils_pissa.py
Normal file
@@ -0,0 +1,569 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Tuple
|
||||
|
||||
from library.network_utils import convert_pissa_to_standard_lora, initialize_pissa
|
||||
|
||||
|
||||
def generate_synthetic_weights(org_weight, seed=42):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Base random normal distribution
|
||||
weights = torch.randn_like(org_weight)
|
||||
|
||||
# Add structured variance to mimic real-world weight matrices
|
||||
# Techniques to create more realistic weight distributions:
|
||||
|
||||
# 1. Block-wise variation
|
||||
block_size = max(1, org_weight.shape[0] // 4)
|
||||
for i in range(0, org_weight.shape[0], block_size):
|
||||
block_end = min(i + block_size, org_weight.shape[0])
|
||||
block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling
|
||||
weights[i:block_end, :] *= 1 + block_variation
|
||||
|
||||
# 2. Sparse connectivity simulation
|
||||
sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity
|
||||
weights *= sparsity_mask.float()
|
||||
|
||||
# 3. Magnitude decay
|
||||
magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1)
|
||||
weights *= magnitude_decay
|
||||
|
||||
# 4. Add structured noise
|
||||
structural_noise = torch.randn_like(org_weight) * 0.1
|
||||
weights += structural_noise
|
||||
|
||||
# Normalize to have similar statistical properties to trained weights
|
||||
weights = (weights - weights.mean()) / weights.std()
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
class TestPissa:
|
||||
"""Test suite for convert_pissa_to_standard_lora function."""
|
||||
|
||||
@pytest.fixture
|
||||
def basic_matrices(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, int]:
|
||||
"""Create basic test matrices with known properties."""
|
||||
torch.manual_seed(42)
|
||||
d_model, rank = 64, 8
|
||||
|
||||
# Create original matrices
|
||||
orig_up = torch.randn(d_model, rank, dtype=torch.float32)
|
||||
orig_down = torch.randn(rank, d_model, dtype=torch.float32)
|
||||
|
||||
# Create trained matrices (slightly different)
|
||||
noise_scale = 0.1
|
||||
trained_up = orig_up + noise_scale * torch.randn_like(orig_up)
|
||||
trained_down = orig_down + noise_scale * torch.randn_like(orig_down)
|
||||
|
||||
return trained_up, trained_down, orig_up, orig_down, rank
|
||||
|
||||
@pytest.fixture
|
||||
def small_matrices(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, int]:
|
||||
"""Create small matrices for easier debugging."""
|
||||
torch.manual_seed(123)
|
||||
d_model, rank = 8, 2
|
||||
|
||||
orig_up = torch.randn(d_model, rank, dtype=torch.float32)
|
||||
orig_down = torch.randn(rank, d_model, dtype=torch.float32)
|
||||
trained_up = orig_up + 0.1 * torch.randn_like(orig_up)
|
||||
trained_down = orig_down + 0.1 * torch.randn_like(orig_down)
|
||||
|
||||
return trained_up, trained_down, orig_up, orig_down, rank
|
||||
|
||||
def test_initialize_pissa_rank_constraints(self):
|
||||
# Test with different rank values
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
torch.nn.init.xavier_uniform_(org_module.weight)
|
||||
torch.nn.init.zeros_(org_module.bias)
|
||||
|
||||
# Test with rank less than min dimension
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Test with rank equal to min dimension
|
||||
lora_down = torch.nn.Linear(20, 10)
|
||||
lora_up = torch.nn.Linear(10, 10)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10)
|
||||
|
||||
def test_initialize_pissa_rank_limits(self):
|
||||
# Test rank limits
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
|
||||
# Test minimum rank (should work)
|
||||
lora_down_min = torch.nn.Linear(10, 1)
|
||||
lora_up_min = torch.nn.Linear(1, 5)
|
||||
initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1)
|
||||
|
||||
# Test maximum rank (rank = min(input_dim, output_dim))
|
||||
max_rank = min(10, 5)
|
||||
lora_down_max = torch.nn.Linear(10, max_rank)
|
||||
lora_up_max = torch.nn.Linear(max_rank, 5)
|
||||
initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank)
|
||||
|
||||
def test_initialize_pissa_basic(self):
|
||||
# Create a simple linear layer
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
torch.nn.init.xavier_uniform_(org_module.weight)
|
||||
torch.nn.init.zeros_(org_module.bias)
|
||||
|
||||
# Create LoRA layers with matching shapes
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
# Store original weight for comparison
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
# Call the initialization function
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
# Verify basic properties
|
||||
assert lora_down.weight.data is not None
|
||||
assert lora_up.weight.data is not None
|
||||
assert org_module.weight.data is not None
|
||||
|
||||
# Check that the weights have been modified
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
def test_initialize_pissa_with_lowrank(self):
|
||||
# Test with low-rank SVD option
|
||||
org_module = torch.nn.Linear(50, 30)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(50, 5)
|
||||
lora_up = torch.nn.Linear(5, 30)
|
||||
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
# Call with low-rank SVD enabled
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True)
|
||||
|
||||
# Verify weights are changed
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
def test_initialize_pissa_custom_lowrank_params(self):
|
||||
# Test with custom low-rank parameters
|
||||
org_module = torch.nn.Linear(30, 20)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(30, 5)
|
||||
lora_up = torch.nn.Linear(5, 20)
|
||||
|
||||
# Test with custom q value and iterations
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=5, use_lowrank=True, lowrank_niter=6)
|
||||
|
||||
# Check basic validity
|
||||
assert lora_down.weight.data is not None
|
||||
assert lora_up.weight.data is not None
|
||||
|
||||
def test_initialize_pissa_device_handling(self):
|
||||
# Test different device scenarios
|
||||
devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")]
|
||||
|
||||
for device in devices:
|
||||
# Create modules on specific device
|
||||
org_module = torch.nn.Linear(10, 5).to(device)
|
||||
lora_down = torch.nn.Linear(10, 2).to(device)
|
||||
lora_up = torch.nn.Linear(2, 5).to(device)
|
||||
|
||||
# Test initialization with explicit device
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device)
|
||||
|
||||
# Verify modules are on the correct device
|
||||
assert org_module.weight.data.device.type == device.type
|
||||
assert lora_down.weight.data.device.type == device.type
|
||||
assert lora_up.weight.data.device.type == device.type
|
||||
|
||||
# Test with IPCA
|
||||
if device.type == "cpu": # IPCA might be slow on CPU for large matrices
|
||||
org_module_small = torch.nn.Linear(20, 10).to(device)
|
||||
lora_down_small = torch.nn.Linear(20, 3).to(device)
|
||||
lora_up_small = torch.nn.Linear(3, 10).to(device)
|
||||
|
||||
initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device)
|
||||
|
||||
assert org_module_small.weight.data.device.type == device.type
|
||||
|
||||
def test_initialize_pissa_shape_mismatch(self):
|
||||
# Test with shape mismatch to ensure warning is printed
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
|
||||
# Intentionally mismatched shapes to test warning mechanism
|
||||
lora_down = torch.nn.Linear(20, 5) # Different shape
|
||||
lora_up = torch.nn.Linear(3, 15) # Different shape
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
def test_initialize_pissa_scaling(self):
|
||||
# Test different scaling factors
|
||||
scales = [0.0, 0.1, 1.0]
|
||||
|
||||
for scale in scales:
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2)
|
||||
|
||||
# Check that the weight modification follows the scaling
|
||||
weight_diff = original_weight - org_module.weight.data
|
||||
expected_diff = scale * (lora_up.weight.data @ lora_down.weight.data)
|
||||
|
||||
torch.testing.assert_close(weight_diff, expected_diff, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_initialize_pissa_dtype(self):
|
||||
# Test with different data types
|
||||
dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
for dtype in dtypes:
|
||||
org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
# Verify output dtype matches input
|
||||
assert org_module.weight.dtype == dtype
|
||||
|
||||
def test_initialize_pissa_svd_properties(self):
|
||||
# Verify SVD decomposition properties
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Reconstruct the weight
|
||||
reconstructed_weight = original_weight - 0.1 * (lora_up.weight.data @ lora_down.weight.data)
|
||||
|
||||
# Check reconstruction is close to original
|
||||
torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def test_initialize_pissa_dtype_preservation(self):
|
||||
# Test dtype preservation and conversion
|
||||
dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
for dtype in dtypes:
|
||||
org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
|
||||
lora_down = torch.nn.Linear(10, 2).to(dtype=dtype)
|
||||
lora_up = torch.nn.Linear(2, 5).to(dtype=dtype)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
assert org_module.weight.dtype == dtype
|
||||
assert lora_down.weight.dtype == dtype
|
||||
assert lora_up.weight.dtype == dtype
|
||||
|
||||
# Test with explicit dtype
|
||||
if dtype != torch.float16: # Skip float16 for computational stability in SVD
|
||||
org_module2 = torch.nn.Linear(10, 5).to(dtype=torch.float32)
|
||||
lora_down2 = torch.nn.Linear(10, 2).to(dtype=torch.float32)
|
||||
lora_up2 = torch.nn.Linear(2, 5).to(dtype=torch.float32)
|
||||
|
||||
initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype)
|
||||
|
||||
# Original module should be converted to specified dtype
|
||||
assert org_module2.weight.dtype == torch.float32
|
||||
|
||||
def test_initialize_pissa_numerical_stability(self):
|
||||
# Test with numerically challenging scenarios
|
||||
scenarios = [
|
||||
torch.randn(20, 10) * 1e-5, # Small values
|
||||
torch.randn(20, 10) * 1e5, # Large values
|
||||
torch.ones(20, 10), # Uniform values
|
||||
]
|
||||
|
||||
for i, weight_matrix in enumerate(scenarios):
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = weight_matrix
|
||||
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Test IPCA as well
|
||||
lora_down_ipca = torch.nn.Linear(20, 3)
|
||||
lora_up_ipca = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Initialization failed for scenario ({i}): {e}")
|
||||
|
||||
def test_initialize_pissa_scale_effects(self):
    """The residual subtracted from the base weight must scale linearly with `scale`."""
    org_module = torch.nn.Linear(15, 10)
    original_weight = torch.randn_like(org_module.weight.data)
    org_module.weight.data = original_weight.clone()

    for scale in (0.0, 0.01, 0.1, 1.0):
        # Restore the pristine weight before each run.
        org_module.weight.data = original_weight.clone()

        lora_down = torch.nn.Linear(15, 4)
        lora_up = torch.nn.Linear(4, 10)

        initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=4)

        # How much the base weight was modified by initialization.
        weight_diff = original_weight - org_module.weight.data

        if scale == 0.0:
            # A zero scale must leave the base weight untouched.
            torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6)
        else:
            # Non-zero scale: something was actually subtracted.
            assert weight_diff.abs().sum() > 0

            # Repeat with twice the scale on a fresh module.
            org_module2 = torch.nn.Linear(15, 10)
            org_module2.weight.data = original_weight.clone()

            lora_down2 = torch.nn.Linear(15, 4)
            lora_up2 = torch.nn.Linear(4, 10)

            initialize_pissa(org_module2, lora_down2, lora_up2, scale=scale * 2, rank=4)

            weight_diff2 = original_weight - org_module2.weight.data

            # Doubling the scale should roughly double the modification
            # (small tolerance for floating-point noise; epsilon guards /0).
            ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10)
            assert 1.9 < ratio < 2.1
|
||||
|
||||
def test_initialize_pissa_large_matrix_performance(self):
    """Exercise PiSSA initialization on a large layer (relevant for the IPCA path)."""
    # Large-matrix SVD is slow on CPU, so only run where CUDA is available.
    if not torch.cuda.is_available():
        pytest.skip("Skipping large matrix test on CPU")

    org_module = torch.nn.Linear(1000, 500)
    org_module.weight.data = torch.randn_like(org_module.weight.data) * 0.1

    # 1) Standard full-SVD approach.
    lora_down = torch.nn.Linear(1000, 16)
    lora_up = torch.nn.Linear(16, 500)
    try:
        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=16)
    except Exception as e:
        pytest.fail(f"Standard SVD failed on large matrix: {e}")

    # 2) IPCA approach.
    lora_down_ipca = torch.nn.Linear(1000, 16)
    lora_up_ipca = torch.nn.Linear(16, 500)
    try:
        initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16)
    except Exception as e:
        pytest.fail(f"IPCA approach failed on large matrix: {e}")

    # 3) IPCA combined with the low-rank SVD path.
    lora_down_both = torch.nn.Linear(1000, 16)
    lora_up_both = torch.nn.Linear(16, 500)
    try:
        initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_lowrank=True)
    except Exception as e:
        pytest.fail(f"Combined IPCA+lowrank approach failed on large matrix: {e}")
|
||||
|
||||
def test_initialize_pissa_requires_grad_preservation(self):
    """initialize_pissa must not flip the base weight's requires_grad flag."""
    # Case 1: frozen base weight stays frozen.
    org_module = torch.nn.Linear(20, 10)
    org_module.weight.requires_grad = False

    lora_down = torch.nn.Linear(20, 4)
    lora_up = torch.nn.Linear(4, 10)

    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=4)
    assert not org_module.weight.requires_grad

    # Case 2: trainable base weight stays trainable
    # (the same LoRA modules are deliberately reused).
    org_module2 = torch.nn.Linear(20, 10)
    org_module2.weight.requires_grad = True

    initialize_pissa(org_module2, lora_down, lora_up, scale=0.1, rank=4)
    assert org_module2.weight.requires_grad
|
||||
|
||||
def test_basic_functionality(self, basic_matrices):
    """Smoke test: conversion runs and yields tensors of the expected shapes."""
    trained_up, trained_down, orig_up, orig_down, rank = basic_matrices

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # Outputs must be plain tensors.
    assert isinstance(new_up, torch.Tensor)
    assert isinstance(new_down, torch.Tensor)

    # Inner dimension is capped by both 2*rank and the matrix dimensions,
    # and the two factors must be compatible for matrix multiplication.
    d_model = trained_up.shape[0]
    expected_rank = min(rank * 2, min(d_model, trained_down.shape[1]))

    assert new_up.shape == torch.Size([d_model, expected_rank])
    assert new_down.shape == (expected_rank, trained_down.shape[1])
|
||||
|
||||
def test_delta_preservation(self, basic_matrices):
    """The LoRA factors should approximately reproduce trained minus original."""
    trained_up, trained_down, orig_up, orig_down, rank = basic_matrices

    # The delta the conversion is supposed to encode.
    original_delta = (trained_up @ trained_down) - (orig_up @ orig_down)

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # Reconstruct the delta from the converted factors.
    reconstructed_delta = new_up @ new_down

    # Rank truncation loses some information, so only require a loose bound.
    relative_error = torch.norm(original_delta - reconstructed_delta) / torch.norm(original_delta)
    assert relative_error < 0.5
|
||||
|
||||
def test_rank_handling(self, small_matrices):
    """A requested rank larger than the matrices allow must be clamped, not crash."""
    trained_up, trained_down, orig_up, orig_down, base_rank = small_matrices
    d_model = trained_up.shape[0]

    # Deliberately exceed the matrix dimensions.
    large_rank = d_model + 5
    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, large_rank)

    # The effective rank is bounded by the available singular values.
    max_possible_rank = min(d_model, trained_down.shape[1])
    assert new_up.shape[1] <= max_possible_rank
    assert new_down.shape[0] <= max_possible_rank
|
||||
|
||||
def test_zero_delta(self):
    """Identical trained/original factors must convert to a (near) zero delta."""
    torch.manual_seed(456)
    d_model, rank = 16, 4

    # Trained factors are exact copies of the originals -> delta is zero.
    orig_up = torch.randn(d_model, rank, dtype=torch.float32)
    orig_down = torch.randn(rank, d_model, dtype=torch.float32)
    trained_up = orig_up.clone()
    trained_down = orig_down.clone()

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # The reconstructed delta must vanish up to numerical noise.
    reconstructed_delta = new_up @ new_down
    assert torch.allclose(reconstructed_delta, torch.zeros_like(reconstructed_delta), atol=1e-6)
|
||||
|
||||
def test_different_devices(self, basic_matrices):
    """Converted factors must live on the same device as the inputs (CPU here)."""
    trained_up, trained_down, orig_up, orig_down, rank = basic_matrices

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # No silent device migration during conversion.
    assert new_up.device == trained_up.device
    assert new_down.device == trained_up.device
|
||||
|
||||
def test_gradient_disabled(self, basic_matrices):
    """Conversion runs under torch.no_grad(), so outputs must not track gradients."""
    trained_up, trained_down, orig_up, orig_down, rank = basic_matrices

    # Even with gradient-tracking inputs...
    trained_up.requires_grad_(True)
    trained_down.requires_grad_(True)

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # ...the outputs stay detached.
    assert not new_up.requires_grad
    assert not new_down.requires_grad
|
||||
|
||||
def test_dtype_consistency(self, basic_matrices):
    """float32 inputs must produce float32 outputs."""
    trained_up, trained_down, orig_up, orig_down, rank = basic_matrices

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # No dtype promotion or downcast during conversion.
    assert new_up.dtype == torch.float32
    assert new_down.dtype == torch.float32
|
||||
|
||||
def test_mathematical_properties(self, small_matrices):
    """The reconstructed delta's rank must respect the SVD truncation bound."""
    trained_up, trained_down, orig_up, orig_down, rank = small_matrices

    # Delta computed directly from the factors.
    delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # new_up @ new_down is a low-rank approximation of delta_w, so its rank
    # cannot exceed min(2*rank, min(delta_w.shape)).
    reconstructed = new_up @ new_down
    actual_rank = torch.linalg.matrix_rank(reconstructed).item()
    expected_max_rank = min(rank * 2, min(delta_w.shape))
    assert actual_rank <= expected_max_rank
|
||||
|
||||
@pytest.mark.parametrize("rank", [1, 4, 8, 16])
def test_different_ranks(self, rank):
    """Conversion must work across a range of rank values."""
    torch.manual_seed(789)
    d_model = 32

    # Trained factors are small perturbations of the originals.
    orig_up = torch.randn(d_model, rank, dtype=torch.float32)
    orig_down = torch.randn(rank, d_model, dtype=torch.float32)
    trained_up = orig_up + 0.1 * torch.randn_like(orig_up)
    trained_down = orig_down + 0.1 * torch.randn_like(orig_down)

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # Outer dimensions preserved; inner dimensions compatible for matmul.
    assert new_up.shape[0] == d_model
    assert new_down.shape[1] == d_model
    assert new_up.shape[1] == new_down.shape[0]
|
||||
|
||||
def test_edge_case_single_rank(self):
    """rank=1 is the minimal case; the output rank may grow to at most 2 (rank * 2)."""
    torch.manual_seed(101)
    d_model, rank = 8, 1

    orig_up = torch.randn(d_model, rank, dtype=torch.float32)
    orig_down = torch.randn(rank, d_model, dtype=torch.float32)
    trained_up = orig_up + 0.2 * torch.randn_like(orig_up)
    trained_down = orig_down + 0.2 * torch.randn_like(orig_down)

    new_up, new_down = convert_pissa_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank)

    # rank * 2 = 2, further capped by the matrix dimension.
    expected_rank = min(2, d_model)
    assert new_up.shape[1] <= expected_rank
    assert new_down.shape[0] <= expected_rank
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly, outside a pytest invocation.
    pytest.main([__file__, "-v"])
|
||||
@@ -4,109 +4,87 @@ import torch
|
||||
from library.network_utils import convert_urae_to_standard_lora
|
||||
|
||||
|
||||
class TestConvertURAEToStandardLoRA:
|
||||
class TestURAE:
|
||||
@pytest.fixture
|
||||
def sample_matrices(self):
|
||||
"""Create sample matrices for testing"""
|
||||
# Original up matrix (4x2)
|
||||
orig_up = torch.tensor([
|
||||
[1.0, 2.0],
|
||||
[3.0, 4.0],
|
||||
[5.0, 6.0],
|
||||
[7.0, 8.0]
|
||||
])
|
||||
|
||||
orig_up = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
|
||||
|
||||
# Original down matrix (2x6)
|
||||
orig_down = torch.tensor([
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
|
||||
])
|
||||
|
||||
orig_down = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6], [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]])
|
||||
|
||||
# Trained up matrix (4x2) - same shape as orig_up but with changed values
|
||||
trained_up = torch.tensor([
|
||||
[1.1, 2.1],
|
||||
[3.1, 4.1],
|
||||
[5.1, 6.1],
|
||||
[7.1, 8.1]
|
||||
])
|
||||
|
||||
trained_up = torch.tensor([[1.1, 2.1], [3.1, 4.1], [5.1, 6.1], [7.1, 8.1]])
|
||||
|
||||
# Trained down matrix (2x6) - same shape as orig_down but with changed values
|
||||
trained_down = torch.tensor([
|
||||
[0.15, 0.25, 0.35, 0.45, 0.55, 0.65],
|
||||
[0.75, 0.85, 0.95, 1.05, 1.15, 1.25]
|
||||
])
|
||||
|
||||
return {
|
||||
'orig_up': orig_up,
|
||||
'orig_down': orig_down,
|
||||
'trained_up': trained_up,
|
||||
'trained_down': trained_down
|
||||
}
|
||||
trained_down = torch.tensor([[0.15, 0.25, 0.35, 0.45, 0.55, 0.65], [0.75, 0.85, 0.95, 1.05, 1.15, 1.25]])
|
||||
|
||||
return {"orig_up": orig_up, "orig_down": orig_down, "trained_up": trained_up, "trained_down": trained_down}
|
||||
|
||||
def test_basic_conversion(self, sample_matrices):
|
||||
"""Test the basic functionality of convert_urae_to_standard_lora"""
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down']
|
||||
sample_matrices["trained_up"], sample_matrices["trained_down"], sample_matrices["orig_up"], sample_matrices["orig_down"]
|
||||
)
|
||||
|
||||
|
||||
# Check shapes
|
||||
assert lora_up.shape[0] == sample_matrices['trained_up'].shape[0] # Same number of rows as trained_up
|
||||
assert lora_up.shape[1] == sample_matrices['trained_up'].shape[1] # Same rank as trained_up
|
||||
assert lora_down.shape[0] == sample_matrices['trained_up'].shape[1] # Same rank as trained_up
|
||||
assert lora_down.shape[1] == sample_matrices['trained_down'].shape[1] # Same number of columns as trained_down
|
||||
|
||||
assert lora_up.shape[0] == sample_matrices["trained_up"].shape[0] # Same number of rows as trained_up
|
||||
assert lora_up.shape[1] == sample_matrices["trained_up"].shape[1] # Same rank as trained_up
|
||||
assert lora_down.shape[0] == sample_matrices["trained_up"].shape[1] # Same rank as trained_up
|
||||
assert lora_down.shape[1] == sample_matrices["trained_down"].shape[1] # Same number of columns as trained_down
|
||||
|
||||
# Check alpha is a reasonable value
|
||||
assert 0.1 <= alpha <= 1024.0
|
||||
|
||||
|
||||
# Check that lora_up @ lora_down approximates the weight delta
|
||||
delta = (sample_matrices['trained_up'] @ sample_matrices['trained_down']) - (sample_matrices['orig_up'] @ sample_matrices['orig_down'])
|
||||
|
||||
delta = (sample_matrices["trained_up"] @ sample_matrices["trained_down"]) - (
|
||||
sample_matrices["orig_up"] @ sample_matrices["orig_down"]
|
||||
)
|
||||
|
||||
# The approximation should be close in Frobenius norm after scaling
|
||||
lora_effect = lora_up @ lora_down
|
||||
delta_norm = torch.norm(delta, p="fro").item()
|
||||
lora_norm = torch.norm(lora_effect, p="fro").item()
|
||||
|
||||
|
||||
# Either they are close, or the alpha scaling brings them close
|
||||
scaled_lora_effect = (alpha / sample_matrices['trained_up'].shape[1]) * lora_effect
|
||||
scaled_lora_effect = (alpha / sample_matrices["trained_up"].shape[1]) * lora_effect
|
||||
scaled_lora_norm = torch.norm(scaled_lora_effect, p="fro").item()
|
||||
|
||||
|
||||
# At least one of these should be true
|
||||
assert abs(delta_norm - lora_norm) < 1e-4 or abs(delta_norm - scaled_lora_norm) < 1e-4
|
||||
|
||||
def test_specified_rank(self, sample_matrices):
|
||||
"""Test conversion with a specified rank"""
|
||||
new_rank = 1 # Lower than trained_up's rank of 2
|
||||
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
rank=new_rank
|
||||
sample_matrices["trained_up"],
|
||||
sample_matrices["trained_down"],
|
||||
sample_matrices["orig_up"],
|
||||
sample_matrices["orig_down"],
|
||||
rank=new_rank,
|
||||
)
|
||||
|
||||
|
||||
# Check that the new rank is used
|
||||
assert lora_up.shape[1] == new_rank
|
||||
assert lora_down.shape[0] == new_rank
|
||||
|
||||
|
||||
# Should still produce a reasonable alpha
|
||||
assert 0.1 <= alpha <= 1024.0
|
||||
|
||||
def test_with_initial_alpha(self, sample_matrices):
|
||||
"""Test conversion with a specified initial alpha"""
|
||||
initial_alpha = 16.0
|
||||
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
initial_alpha=initial_alpha
|
||||
sample_matrices["trained_up"],
|
||||
sample_matrices["trained_down"],
|
||||
sample_matrices["orig_up"],
|
||||
sample_matrices["orig_down"],
|
||||
initial_alpha=initial_alpha,
|
||||
)
|
||||
|
||||
|
||||
# Alpha should be influenced by initial_alpha but may be adjusted
|
||||
# Since we're using same rank, should be reasonably close to initial_alpha
|
||||
assert 0.1 <= alpha <= 1024.0
|
||||
@@ -116,30 +94,30 @@ class TestConvertURAEToStandardLoRA:
|
||||
def test_large_initial_alpha(self, sample_matrices):
|
||||
"""Test conversion with a very large initial alpha that should be capped"""
|
||||
initial_alpha = 2000.0 # Larger than the 1024.0 cap
|
||||
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
initial_alpha=initial_alpha
|
||||
sample_matrices["trained_up"],
|
||||
sample_matrices["trained_down"],
|
||||
sample_matrices["orig_up"],
|
||||
sample_matrices["orig_down"],
|
||||
initial_alpha=initial_alpha,
|
||||
)
|
||||
|
||||
|
||||
# Alpha should be capped at 1024.0
|
||||
assert alpha <= 1024.0
|
||||
|
||||
def test_very_small_initial_alpha(self, sample_matrices):
|
||||
"""Test conversion with a very small initial alpha that should be floored"""
|
||||
initial_alpha = 0.01 # Smaller than the 0.1 floor
|
||||
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
initial_alpha=initial_alpha
|
||||
sample_matrices["trained_up"],
|
||||
sample_matrices["trained_down"],
|
||||
sample_matrices["orig_up"],
|
||||
sample_matrices["orig_down"],
|
||||
initial_alpha=initial_alpha,
|
||||
)
|
||||
|
||||
|
||||
# Alpha should be floored at 0.1
|
||||
assert alpha >= 0.1
|
||||
|
||||
@@ -147,22 +125,22 @@ class TestConvertURAEToStandardLoRA:
|
||||
"""Test conversion with both rank change and initial alpha"""
|
||||
initial_alpha = 16.0
|
||||
new_rank = 1 # Half of original rank 2
|
||||
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
sample_matrices["trained_up"],
|
||||
sample_matrices["trained_down"],
|
||||
sample_matrices["orig_up"],
|
||||
sample_matrices["orig_down"],
|
||||
initial_alpha=initial_alpha,
|
||||
rank=new_rank
|
||||
rank=new_rank,
|
||||
)
|
||||
|
||||
|
||||
# Check shapes
|
||||
assert lora_up.shape[1] == new_rank
|
||||
assert lora_down.shape[0] == new_rank
|
||||
|
||||
|
||||
# Alpha should be adjusted for the rank change (approx halved in this case)
|
||||
expected_alpha = initial_alpha * (new_rank / sample_matrices['trained_up'].shape[1])
|
||||
expected_alpha = initial_alpha * (new_rank / sample_matrices["trained_up"].shape[1])
|
||||
# Allow some tolerance for adjustments from norm-based capping
|
||||
assert abs(alpha - expected_alpha) <= expected_alpha * 4.0 or alpha >= 0.1
|
||||
|
||||
@@ -170,47 +148,37 @@ class TestConvertURAEToStandardLoRA:
|
||||
"""Test conversion when delta is zero"""
|
||||
# Create matrices where the delta will be zero
|
||||
dim_in, rank, dim_out = 4, 2, 6
|
||||
|
||||
|
||||
# Create identical matrices for original and trained
|
||||
orig_up = torch.randn(dim_in, rank)
|
||||
orig_down = torch.randn(rank, dim_out)
|
||||
trained_up = orig_up.clone()
|
||||
trained_down = orig_down.clone()
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
trained_up,
|
||||
trained_down,
|
||||
orig_up,
|
||||
orig_down
|
||||
)
|
||||
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(trained_up, trained_down, orig_up, orig_down)
|
||||
|
||||
# Should still return matrices of correct shape
|
||||
assert lora_up.shape == (dim_in, rank)
|
||||
assert lora_down.shape == (rank, dim_out)
|
||||
|
||||
|
||||
# Alpha should be at least the minimum
|
||||
assert alpha >= 0.1
|
||||
|
||||
def test_large_dimensions(self):
|
||||
"""Test with larger matrix dimensions"""
|
||||
dim_in, rank, dim_out = 100, 8, 200
|
||||
|
||||
|
||||
orig_up = torch.randn(dim_in, rank)
|
||||
orig_down = torch.randn(rank, dim_out)
|
||||
trained_up = orig_up + 0.01 * torch.randn(dim_in, rank) # Small perturbation
|
||||
trained_down = orig_down + 0.01 * torch.randn(rank, dim_out) # Small perturbation
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
trained_up,
|
||||
trained_down,
|
||||
orig_up,
|
||||
orig_down
|
||||
)
|
||||
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(trained_up, trained_down, orig_up, orig_down)
|
||||
|
||||
# Check shapes
|
||||
assert lora_up.shape == (dim_in, rank)
|
||||
assert lora_down.shape == (rank, dim_out)
|
||||
|
||||
|
||||
# Should produce a reasonable alpha
|
||||
assert 0.1 <= alpha <= 1024.0
|
||||
|
||||
@@ -218,23 +186,17 @@ class TestConvertURAEToStandardLoRA:
|
||||
"""Test when requested rank exceeds available singular values"""
|
||||
# Small matrices with limited rank
|
||||
dim_in, rank, dim_out = 3, 2, 3
|
||||
|
||||
|
||||
orig_up = torch.randn(dim_in, rank)
|
||||
orig_down = torch.randn(rank, dim_out)
|
||||
trained_up = orig_up + 0.1 * torch.randn(dim_in, rank)
|
||||
trained_down = orig_down + 0.1 * torch.randn(rank, dim_out)
|
||||
|
||||
|
||||
# Request rank larger than possible
|
||||
too_large_rank = 10
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
trained_up,
|
||||
trained_down,
|
||||
orig_up,
|
||||
orig_down,
|
||||
rank=too_large_rank
|
||||
)
|
||||
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(trained_up, trained_down, orig_up, orig_down, rank=too_large_rank)
|
||||
|
||||
# Rank should be limited to min(dim_in, dim_out, S.size)
|
||||
max_possible_rank = min(dim_in, dim_out)
|
||||
assert lora_up.shape[1] <= max_possible_rank
|
||||
|
||||
Reference in New Issue
Block a user