# Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-16 17:02:45 +00:00).
# Upstream file as mirrored: 390 lines, 14 KiB, Python.
import torch
|
|
import pytest
|
|
from library.network_utils import initialize_pissa
|
|
|
|
|
|
def generate_synthetic_weights(org_weight, seed=42):
    """Build a deterministic, structured random tensor shaped like ``org_weight``.

    The result starts as standard-normal noise and is then perturbed with
    per-row-block scaling, a random zero mask, a linear row-wise decay and a
    small amount of extra noise, before being standardised to zero mean and
    unit standard deviation.

    Note that ``torch.manual_seed`` is called, so the *global* torch RNG is
    reseeded as a side effect of every call.

    Args:
        org_weight: Tensor whose shape (and dtype/device, via ``randn_like``)
            the generated weights follow. It is not modified.
        seed: Seed passed to ``torch.manual_seed``.

    Returns:
        A new tensor with the same shape as ``org_weight``.
    """
    rng = torch.manual_seed(seed)
    n_rows = org_weight.shape[0]

    # Start from plain Gaussian noise.
    synthetic = torch.randn_like(org_weight)

    # 1. Scale each block of rows (about a quarter of the rows per block)
    #    by its own random factor around 1.
    rows_per_block = max(1, n_rows // 4)
    for start in range(0, n_rows, rows_per_block):
        stop = min(start + rows_per_block, n_rows)
        local_scale = torch.randn(1, generator=rng) * 0.3
        synthetic[start:stop, :].mul_(1 + local_scale)

    # 2. Zero out roughly 20% of the entries.
    keep_mask = torch.rand(org_weight.shape, generator=rng) > 0.2
    synthetic.mul_(keep_mask.float())

    # 3. Rows decay linearly in magnitude from 1.0 (first) to 0.5 (last).
    row_decay = torch.linspace(1.0, 0.5, n_rows).unsqueeze(1)
    synthetic.mul_(row_decay)

    # 4. Add a little extra noise on top.
    extra_noise = torch.randn_like(org_weight) * 0.1
    synthetic.add_(extra_noise)

    # Standardise to zero mean / unit standard deviation.
    return (synthetic - synthetic.mean()) / synthetic.std()
|
|
|
|
|
|
def test_initialize_pissa_rank_constraints():
    """Run ``initialize_pissa`` on a 20->10 layer at rank 3 and at rank 10.

    The test passes when neither call raises; nothing is asserted about the
    resulting weights.
    """
    in_features, out_features = 20, 10

    org_module = torch.nn.Linear(in_features, out_features)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    # NOTE(review): the Xavier init below overwrites the synthetic weights
    # assigned just above, so only the Xavier values reach initialize_pissa.
    torch.nn.init.xavier_uniform_(org_module.weight)
    torch.nn.init.zeros_(org_module.bias)

    # Rank 3 is below min(in, out); rank 10 equals it.
    for rank in (3, 10):
        lora_down = torch.nn.Linear(in_features, rank)
        lora_up = torch.nn.Linear(rank, out_features)
        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank)
|
|
|
|
|
|
def test_initialize_pissa_rank_limits():
    """Run ``initialize_pissa`` at the smallest and largest rank for a 10->5 layer.

    Rank 1 and rank ``min(10, 5)`` are both exercised on the same base module;
    the test passes when neither call raises.
    """
    in_features, out_features = 10, 5
    org_module = torch.nn.Linear(in_features, out_features)

    lowest_rank = 1
    highest_rank = min(in_features, out_features)

    for rank in (lowest_rank, highest_rank):
        lora_down = torch.nn.Linear(in_features, rank)
        lora_up = torch.nn.Linear(rank, out_features)
        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank)
|
|
|
|
|
|
def test_initialize_pissa_basic():
    """Smoke-test ``initialize_pissa`` on a 10->5 layer at rank 2.

    Asserts that the weight tensors exist afterwards and that the base
    module's weight differs from its value before the call.
    """
    rank = 2

    org_module = torch.nn.Linear(10, 5)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    # NOTE(review): the Xavier init below overwrites the synthetic weights
    # assigned just above, so only the Xavier values reach initialize_pissa.
    torch.nn.init.xavier_uniform_(org_module.weight)
    torch.nn.init.zeros_(org_module.bias)

    # LoRA pair sized for the chosen rank.
    lora_down = torch.nn.Linear(10, rank)
    lora_up = torch.nn.Linear(rank, 5)

    # Snapshot taken before the call so the change can be detected.
    weight_before = org_module.weight.data.clone()

    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank)

    # Basic sanity: every weight tensor is still present.
    for module in (lora_down, lora_up, org_module):
        assert module.weight.data is not None

    # The base weight must have been altered by the initialization.
    assert not torch.equal(weight_before, org_module.weight.data)
|
|
|
|
|
|
def test_initialize_pissa_with_lowrank():
    """Call ``initialize_pissa`` with ``use_lowrank=True`` on a 50->30 layer.

    Asserts that the base module's weight is no longer equal to its value
    before the call.
    """
    rank = 5

    org_module = torch.nn.Linear(50, 30)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    lora_down = torch.nn.Linear(50, rank)
    lora_up = torch.nn.Linear(rank, 30)

    weight_before = org_module.weight.data.clone()

    # Low-rank option switched on.
    initialize_pissa(
        org_module,
        lora_down,
        lora_up,
        scale=0.1,
        rank=rank,
        use_lowrank=True,
    )

    # The base weight must differ from the snapshot.
    assert not torch.equal(weight_before, org_module.weight.data)
|
|
|
|
|
|
def test_initialize_pissa_custom_lowrank_params():
    """Call ``initialize_pissa`` with explicit ``lowrank_q`` / ``lowrank_niter``.

    Uses a 30->20 layer at rank 5 with ``use_lowrank=True`` and only checks
    that the LoRA weight tensors exist afterwards.
    """
    rank = 5

    org_module = torch.nn.Linear(30, 20)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    lora_down = torch.nn.Linear(30, rank)
    lora_up = torch.nn.Linear(rank, 20)

    # Non-default q and iteration count for the low-rank path.
    lowrank_options = {"use_lowrank": True, "lowrank_q": 12, "lowrank_niter": 6}
    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank, **lowrank_options)

    # Basic validity of the LoRA weights.
    for lora_layer in (lora_down, lora_up):
        assert lora_layer.weight.data is not None
|
|
|
|
|
|
def test_initialize_pissa_device_handling():
    """Check that weights stay on the device the modules were created on.

    The CPU is always tested; CUDA is added only when it is available, so the
    CPU case is not executed twice on machines without a GPU.
    """
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))

    for device in devices:
        # Create all modules on the device under test.
        org_module = torch.nn.Linear(10, 5).to(device)
        lora_down = torch.nn.Linear(10, 2).to(device)
        lora_up = torch.nn.Linear(2, 5).to(device)

        # Initialize with the device passed explicitly.
        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device)

        # Every weight must still live on that device type.
        assert org_module.weight.data.device.type == device.type
        assert lora_down.weight.data.device.type == device.type
        assert lora_up.weight.data.device.type == device.type

        # A second, slightly larger configuration is exercised on CPU only.
        # NOTE(review): this was labelled as an IPCA test, but no
        # IPCA-specific argument is passed to initialize_pissa - confirm
        # whether one was intended.
        if device.type == "cpu":
            org_module_small = torch.nn.Linear(20, 10).to(device)
            lora_down_small = torch.nn.Linear(20, 3).to(device)
            lora_up_small = torch.nn.Linear(3, 10).to(device)

            initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device)

            assert org_module_small.weight.data.device.type == device.type
|
|
|
|
|
|
def test_initialize_pissa_shape_mismatch():
    """Expect a ``UserWarning`` when the LoRA layers have mismatched shapes.

    The base layer is 20->10 and the requested rank is 3, while the LoRA
    layers are deliberately built as 20->5 and 3->15.
    """
    base_layer = torch.nn.Linear(20, 10)

    # Deliberately inconsistent with both the base layer and rank=3.
    mismatched_down = torch.nn.Linear(20, 5)
    mismatched_up = torch.nn.Linear(3, 15)

    with pytest.warns(UserWarning):
        initialize_pissa(base_layer, mismatched_down, mismatched_up, scale=0.1, rank=3)
|
|
|
|
|
|
def test_initialize_pissa_scaling():
    """Check the base-weight change equals ``scale * (lora_up @ lora_down)``.

    Repeated for scales 0.0, 0.1 and 1.0 on a fresh 10->5 layer each time.
    """
    rank = 2

    for scale in (0.0, 0.1, 1.0):
        org_module = torch.nn.Linear(10, 5)
        org_module.weight.data = generate_synthetic_weights(org_module.weight)
        weight_before = org_module.weight.data.clone()

        lora_down = torch.nn.Linear(10, rank)
        lora_up = torch.nn.Linear(rank, 5)

        initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=rank)

        # What was removed from the base weight ...
        removed = weight_before - org_module.weight.data
        # ... must match the scaled product of the LoRA factors.
        lora_product = lora_up.weight.data @ lora_down.weight.data
        expected_removed = scale * lora_product

        torch.testing.assert_close(removed, expected_removed, rtol=1e-4, atol=1e-4)
|
|
|
|
|
|
def test_initialize_pissa_dtype():
    """Check the base weight keeps its dtype across ``initialize_pissa``.

    The base layer is cast to float16, float32 and float64 in turn; the LoRA
    layers are created without an explicit dtype.
    """
    for dtype in (torch.float16, torch.float32, torch.float64):
        org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
        org_module.weight.data = generate_synthetic_weights(org_module.weight)

        lora_down = torch.nn.Linear(10, 2)
        lora_up = torch.nn.Linear(2, 5)

        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)

        # The dtype after the call must be the one the layer started with.
        assert org_module.weight.dtype == dtype
|
|
|
|
|
|
def test_initialize_pissa_svd_properties():
    """Check the updated base weight equals original minus the scaled LoRA product.

    Uses a 20->10 layer at rank 3 and scale 0.1.
    """
    scale = 0.1
    rank = 3

    org_module = torch.nn.Linear(20, 10)
    lora_down = torch.nn.Linear(20, rank)
    lora_up = torch.nn.Linear(rank, 10)

    org_module.weight.data = generate_synthetic_weights(org_module.weight)
    weight_before = org_module.weight.data.clone()

    initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=rank)

    # Rebuild the expected weight from the snapshot and the LoRA factors.
    lora_product = lora_up.weight.data @ lora_down.weight.data
    expected_weight = weight_before - scale * lora_product

    # It must agree with what the module now holds.
    torch.testing.assert_close(expected_weight, org_module.weight.data, rtol=1e-4, atol=1e-4)
|
|
|
|
|
|
def test_initialize_pissa_dtype_preservation():
    """Check that module dtypes are unchanged by ``initialize_pissa``.

    First part: base and LoRA layers all share one dtype (float16, float32,
    float64 in turn) and must keep it. Second part (skipped for float16):
    float32 modules are initialized with an explicit ``dtype`` argument and the
    base weight is expected to remain float32.
    """
    for dtype in (torch.float16, torch.float32, torch.float64):
        org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
        lora_down = torch.nn.Linear(10, 2).to(dtype=dtype)
        lora_up = torch.nn.Linear(2, 5).to(dtype=dtype)

        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)

        for module in (org_module, lora_down, lora_up):
            assert module.weight.dtype == dtype

        # Explicit-dtype variant; float16 is skipped for computational
        # stability in SVD.
        if dtype == torch.float16:
            continue

        org_module2 = torch.nn.Linear(10, 5).to(dtype=torch.float32)
        lora_down2 = torch.nn.Linear(10, 2).to(dtype=torch.float32)
        lora_up2 = torch.nn.Linear(2, 5).to(dtype=torch.float32)

        initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype)

        # The base module's weight is expected to stay float32 regardless of
        # the dtype argument passed to initialize_pissa.
        assert org_module2.weight.dtype == torch.float32
|
|
|
|
|
|
def test_initialize_pissa_numerical_stability():
    """Run ``initialize_pissa`` on very small, very large and constant weights.

    Each scenario replaces the weight of a 20->10 linear layer. The scenario
    tensors are shaped ``(out_features, in_features) == (10, 20)`` so they match
    the layout of ``torch.nn.Linear(20, 10).weight`` and the LoRA layers used
    below. Any exception raised during initialization fails the test with the
    scenario index in the message.
    """
    in_features, out_features = 20, 10
    weight_shape = (out_features, in_features)  # nn.Linear stores (out, in)

    scenarios = [
        torch.randn(weight_shape) * 1e-5,  # Small values
        torch.randn(weight_shape) * 1e5,  # Large values
        torch.ones(weight_shape),  # Uniform values
    ]

    for i, weight_matrix in enumerate(scenarios):
        org_module = torch.nn.Linear(in_features, out_features)
        org_module.weight.data = weight_matrix

        lora_down = torch.nn.Linear(in_features, 3)
        lora_up = torch.nn.Linear(3, out_features)

        try:
            initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)

            # Second pass with fresh LoRA layers on the already-updated module.
            # NOTE(review): this was labelled as an IPCA test, but no
            # IPCA-specific argument is passed - confirm whether one was
            # intended.
            lora_down_ipca = torch.nn.Linear(in_features, 3)
            lora_up_ipca = torch.nn.Linear(3, out_features)
            initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3)
        except Exception as e:
            pytest.fail(f"Initialization failed for scenario ({i}): {e}")
|
|
|
|
|
|
def test_initialize_pissa_scale_effects():
    """Check how the base-weight change grows with ``scale``.

    For scale 0.0 the weight must be unchanged. For each non-zero scale the
    change must be non-zero, and a second run at twice the scale must change
    the weight by roughly twice as much (summed absolute difference ratio
    between 1.9 and 2.1).
    """
    rank = 4

    org_module = torch.nn.Linear(15, 10)
    original_weight = torch.randn_like(org_module.weight.data)
    org_module.weight.data = original_weight.clone()

    for scale in (0.0, 0.01, 0.1, 1.0):
        # Every scale starts from the same base weights.
        org_module.weight.data = original_weight.clone()

        lora_down = torch.nn.Linear(15, rank)
        lora_up = torch.nn.Linear(rank, 10)

        initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=rank)

        change = original_weight - org_module.weight.data

        if scale == 0.0:
            # Zero scale: nothing may be removed from the base weight.
            torch.testing.assert_close(change, torch.zeros_like(change), rtol=1e-4, atol=1e-6)
            continue

        # Non-zero scale: something must have changed.
        assert change.abs().sum() > 0

        # Repeat on a fresh module at double the scale.
        org_module2 = torch.nn.Linear(15, 10)
        org_module2.weight.data = original_weight.clone()

        lora_down2 = torch.nn.Linear(15, rank)
        lora_up2 = torch.nn.Linear(rank, 10)

        initialize_pissa(org_module2, lora_down2, lora_up2, scale=scale * 2, rank=rank)

        doubled_change = original_weight - org_module2.weight.data

        # Doubling the scale should about double the total change; the small
        # epsilon guards the division and the band allows for rounding.
        ratio = doubled_change.abs().sum() / (change.abs().sum() + 1e-10)
        assert 1.9 < ratio < 2.1
|
|
|
|
|
|
def test_initialize_pissa_large_matrix_performance():
    """Run ``initialize_pissa`` three times on a 1000->500 layer on the GPU.

    The test is skipped when CUDA is unavailable. When it does run, all
    modules are placed on the CUDA device and that device is passed to
    ``initialize_pissa``, so the work happens on the GPU the skip condition
    checks for. Any exception fails the test with a per-case message.
    """
    # Skip on CPU-only machines to avoid long test times.
    if not torch.cuda.is_available():
        pytest.skip("Skipping large matrix test on CPU")

    device = torch.device("cuda")
    rank = 16

    org_module = torch.nn.Linear(1000, 500).to(device)
    org_module.weight.data = torch.randn_like(org_module.weight.data) * 0.1

    # (extra keyword arguments, failure-message prefix) per case.
    # NOTE(review): the second and third messages mention IPCA, but no
    # IPCA-specific argument is passed in any case - confirm whether one was
    # intended. The messages are kept as they were.
    cases = [
        ({}, "Standard SVD failed on large matrix"),
        ({}, "IPCA approach failed on large matrix"),
        ({"use_lowrank": True}, "Combined IPCA+lowrank approach failed on large matrix"),
    ]

    for extra_kwargs, failure_prefix in cases:
        # Fresh LoRA layers for every case; the base module is shared.
        lora_down = torch.nn.Linear(1000, rank).to(device)
        lora_up = torch.nn.Linear(rank, 500).to(device)

        try:
            initialize_pissa(
                org_module,
                lora_down,
                lora_up,
                scale=0.1,
                rank=rank,
                device=device,
                **extra_kwargs,
            )
        except Exception as e:
            pytest.fail(f"{failure_prefix}: {e}")
|
|
|
|
|
|
def test_initialize_pissa_requires_grad_preservation():
    """Check the base weight's ``requires_grad`` flag survives ``initialize_pissa``.

    One 20->10 module is tested with the flag off and a second one with the
    flag on; both calls share the same pair of rank-4 LoRA layers.
    """
    rank = 4

    # Case 1: gradient tracking disabled on the base weight.
    frozen_module = torch.nn.Linear(20, 10)
    frozen_module.weight.requires_grad_(False)

    lora_down = torch.nn.Linear(20, rank)
    lora_up = torch.nn.Linear(rank, 10)

    initialize_pissa(frozen_module, lora_down, lora_up, scale=0.1, rank=rank)

    assert not frozen_module.weight.requires_grad

    # Case 2: gradient tracking enabled on the base weight.
    trainable_module = torch.nn.Linear(20, 10)
    trainable_module.weight.requires_grad_(True)

    initialize_pissa(trainable_module, lora_down, lora_up, scale=0.1, rank=rank)

    assert trainable_module.weight.requires_grad
|