mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
WIP: Updated PiSSA and URAE initialization
This commit is contained in:
@@ -121,51 +121,52 @@ def initialize_urae(
|
||||
# Move original weight to chosen device and use float32 for numerical stability
|
||||
weight = org_module.weight.data.to(device, dtype=torch.float32)
|
||||
|
||||
# Perform SVD decomposition (either directly or with IPCA for memory efficiency)
|
||||
if use_ipca:
|
||||
ipca = IncrementalPCA(
|
||||
n_components=None,
|
||||
batch_size=1024,
|
||||
lowrank=use_lowrank,
|
||||
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
|
||||
lowrank_niter=lowrank_niter,
|
||||
lowrank_seed=lowrank_seed,
|
||||
)
|
||||
ipca.fit(weight)
|
||||
with torch.autocast(device.type), torch.no_grad():
|
||||
# Perform SVD decomposition (either directly or with IPCA for memory efficiency)
|
||||
if use_ipca:
|
||||
ipca = IncrementalPCA(
|
||||
n_components=None,
|
||||
batch_size=1024,
|
||||
lowrank=use_lowrank,
|
||||
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
|
||||
lowrank_niter=lowrank_niter,
|
||||
lowrank_seed=lowrank_seed,
|
||||
)
|
||||
ipca.fit(weight)
|
||||
|
||||
# Extract singular values and vectors, focusing on the minor components (smallest singular values)
|
||||
S_full = ipca.singular_values_
|
||||
V_full = ipca.components_.T # Shape: [out_features, total_rank]
|
||||
# Extract singular values and vectors, focusing on the minor components (smallest singular values)
|
||||
S_full = ipca.singular_values_
|
||||
V_full = ipca.components_.T # Shape: [out_features, total_rank]
|
||||
|
||||
# Get identity matrix to transform for right singular vectors
|
||||
identity = torch.eye(weight.shape[1], device=weight.device)
|
||||
Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features]
|
||||
# Get identity matrix to transform for right singular vectors
|
||||
identity = torch.eye(weight.shape[1], device=weight.device)
|
||||
Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features]
|
||||
|
||||
# Extract the last 'rank' components (the minor/smallest ones)
|
||||
Sr = S_full[-rank:]
|
||||
Vr = V_full[:, -rank:]
|
||||
Uhr = Uhr_full[-rank:]
|
||||
# Extract the last 'rank' components (the minor/smallest ones)
|
||||
Sr = S_full[-rank:]
|
||||
Vr = V_full[:, -rank:]
|
||||
Uhr = Uhr_full[-rank:]
|
||||
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
else:
|
||||
# Direct SVD approach
|
||||
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
else:
|
||||
# Direct SVD approach
|
||||
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
|
||||
|
||||
# Extract the minor components (smallest singular values)
|
||||
Sr = S[-rank:]
|
||||
Vr = U[:, -rank:]
|
||||
Uhr = Vh[-rank:]
|
||||
# Extract the minor components (smallest singular values)
|
||||
Sr = S[-rank:]
|
||||
Vr = U[:, -rank:]
|
||||
Uhr = Vh[-rank:]
|
||||
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
|
||||
# Create the low-rank adapter matrices by splitting the minor components
|
||||
# Down matrix: scaled right singular vectors with singular values
|
||||
down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
# Create the low-rank adapter matrices by splitting the minor components
|
||||
# Down matrix: scaled right singular vectors with singular values
|
||||
down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
|
||||
# Up matrix: scaled left singular vectors with singular values
|
||||
up_matrix = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
# Up matrix: scaled left singular vectors with singular values
|
||||
up_matrix = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
|
||||
# Assign to LoRA modules
|
||||
lora_down.weight.data = down_matrix.to(device=device, dtype=dtype)
|
||||
@@ -223,7 +224,8 @@ def initialize_pissa(
|
||||
|
||||
# We need to get Uhr from transforming an identity matrix
|
||||
identity = torch.eye(weight.shape[1], device=weight.device)
|
||||
Uhr = ipca.transform(identity).T # [rank, in_features]
|
||||
with torch.autocast(device.type, dtype=torch.float64):
|
||||
Uhr = ipca.transform(identity).T # [rank, in_features]
|
||||
|
||||
elif use_lowrank:
|
||||
# Use low-rank SVD approximation which is faster
|
||||
@@ -248,9 +250,11 @@ def initialize_pissa(
|
||||
Sr /= rank
|
||||
Uhr = Uh[:rank]
|
||||
|
||||
# Create down and up matrices
|
||||
down = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
up = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
# Uhr may be in higher precision
|
||||
with torch.autocast(device.type, dtype=Uhr.dtype):
|
||||
# Create down and up matrices
|
||||
down = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
up = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
|
||||
# Get expected shapes
|
||||
expected_down_shape = lora_down.weight.shape
|
||||
@@ -272,19 +276,20 @@ def initialize_pissa(
|
||||
|
||||
|
||||
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
    """Re-express a trained PiSSA adapter as standard LoRA up/down matrices.

    The learned update is ``delta_w = trained_up @ trained_down - orig_up @ orig_down``.
    It is factorised with an SVD and truncated to the top ``2 * rank`` singular
    values (capped at the number of singular values available).

    Args:
        trained_up: Up matrix after training.
        trained_down: Down matrix after training.
        orig_up: Up matrix as it was before training.
        orig_down: Down matrix as it was before training.
        rank: Rank of the PiSSA adapter; the returned factors use up to twice
            this rank.

    Returns:
        ``(new_up, new_down)``: float32 tensors, created without autograd
        tracking, whose product approximates ``delta_w``. They live on the
        device the SVD ran on (CUDA when available, otherwise the inputs'
        device).
    """
    with torch.no_grad():
        # Calculate ΔW = A'B' - AB
        delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)

        # The SVD used to be pinned to "cuda", which raises on hosts without a
        # GPU. Keep using CUDA when it exists (unchanged behaviour there) and
        # otherwise fall back to whatever device the inputs already live on.
        svd_device = torch.device("cuda") if torch.cuda.is_available() else delta_w.device

        # torch.linalg.svd returns the right singular vectors already
        # transposed (Vh), so rows of Vh are indexed below.
        U, S, Vh = torch.linalg.svd(delta_w.to(device=svd_device, dtype=torch.float32), full_matrices=False)

        # Take the top 2*r singular values (as suggested in the paper),
        # without exceeding the number of singular values available.
        rank = min(rank * 2, len(S))

        # Split each singular value evenly between the two factors.
        # Broadcasting is equivalent to multiplying by diag(sqrt(S)) but
        # avoids materialising the diagonal matrix.
        sqrt_s = torch.sqrt(S[:rank])
        new_up = U[:, :rank] * sqrt_s
        new_down = sqrt_s.unsqueeze(1) * Vh[:rank, :]

    # These matrices can now be used as standard LoRA weights
    return new_up, new_down
|
||||
@@ -314,63 +319,64 @@ def convert_urae_to_standard_lora(
|
||||
lora_down: Standard LoRA down matrix
|
||||
alpha: Appropriate alpha value for the LoRA
|
||||
"""
|
||||
# Calculate the weight delta
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
with torch.no_grad():
|
||||
# Calculate the weight delta
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# Perform SVD on the delta
|
||||
U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False)
|
||||
# Perform SVD on the delta
|
||||
U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# If rank is not specified, use the same rank as the trained matrices
|
||||
if rank is None:
|
||||
rank = trained_up.shape[1]
|
||||
else:
|
||||
# Ensure we don't exceed available singular values
|
||||
rank = min(rank, len(S))
|
||||
# If rank is not specified, use the same rank as the trained matrices
|
||||
if rank is None:
|
||||
rank = trained_up.shape[1]
|
||||
else:
|
||||
# Ensure we don't exceed available singular values
|
||||
rank = min(rank, len(S))
|
||||
|
||||
# Create standard LoRA matrices using top singular values
|
||||
# This is now standard LoRA (using top values), not URAE (which used bottom values during training)
|
||||
lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
|
||||
lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
|
||||
# Create standard LoRA matrices using top singular values
|
||||
# This is now standard LoRA (using top values), not URAE (which used bottom values during training)
|
||||
lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
|
||||
lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
|
||||
|
||||
# Method 1: Preserve the Frobenius norm of the delta
|
||||
original_effect: float = torch.norm(delta_w, p="fro").item()
|
||||
unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item()
|
||||
# Method 1: Preserve the Frobenius norm of the delta
|
||||
original_effect: float = torch.norm(delta_w, p="fro").item()
|
||||
unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item()
|
||||
|
||||
# The scaling factor in lora is (alpha/r), so:
|
||||
# alpha/r × ||AB|| = ||delta_W||
|
||||
# alpha = r × ||delta_W|| / ||AB||
|
||||
if unscaled_lora_effect > 0:
|
||||
norm_based_alpha = rank * (original_effect / unscaled_lora_effect)
|
||||
else:
|
||||
norm_based_alpha = 1.0 # Fallback
|
||||
# The scaling factor in lora is (alpha/r), so:
|
||||
# alpha/r × ||AB|| = ||delta_W||
|
||||
# alpha = r × ||delta_W|| / ||AB||
|
||||
if unscaled_lora_effect > 0:
|
||||
norm_based_alpha = rank * (original_effect / unscaled_lora_effect)
|
||||
else:
|
||||
norm_based_alpha = 1.0 # Fallback
|
||||
|
||||
# Method 2: If initial_alpha is provided, adjust based on rank change
|
||||
if initial_alpha is not None:
|
||||
initial_rank = trained_up.shape[1]
|
||||
# Scale alpha proportionally if rank changed
|
||||
rank_adjusted_alpha = initial_alpha * (rank / initial_rank)
|
||||
else:
|
||||
rank_adjusted_alpha = None
|
||||
# Method 2: If initial_alpha is provided, adjust based on rank change
|
||||
if initial_alpha is not None:
|
||||
initial_rank = trained_up.shape[1]
|
||||
# Scale alpha proportionally if rank changed
|
||||
rank_adjusted_alpha = initial_alpha * (rank / initial_rank)
|
||||
else:
|
||||
rank_adjusted_alpha = None
|
||||
|
||||
# Choose the appropriate alpha
|
||||
if rank_adjusted_alpha is not None:
|
||||
# Use the rank-adjusted alpha, but ensure it's not too different from norm-based
|
||||
# Cap the difference to avoid extreme values
|
||||
alpha = rank_adjusted_alpha
|
||||
# Optional: Cap alpha to be within a reasonable range of norm_based_alpha
|
||||
if norm_based_alpha > 0:
|
||||
max_factor = 5.0 # Allow up to 5x difference
|
||||
upper_bound = norm_based_alpha * max_factor
|
||||
lower_bound = norm_based_alpha / max_factor
|
||||
alpha = min(max(alpha, lower_bound), upper_bound)
|
||||
else:
|
||||
# Use norm-based alpha
|
||||
alpha = norm_based_alpha
|
||||
# Choose the appropriate alpha
|
||||
if rank_adjusted_alpha is not None:
|
||||
# Use the rank-adjusted alpha, but ensure it's not too different from norm-based
|
||||
# Cap the difference to avoid extreme values
|
||||
alpha = rank_adjusted_alpha
|
||||
# Optional: Cap alpha to be within a reasonable range of norm_based_alpha
|
||||
if norm_based_alpha > 0:
|
||||
max_factor = 5.0 # Allow up to 5x difference
|
||||
upper_bound = norm_based_alpha * max_factor
|
||||
lower_bound = norm_based_alpha / max_factor
|
||||
alpha = min(max(alpha, lower_bound), upper_bound)
|
||||
else:
|
||||
# Use norm-based alpha
|
||||
alpha = norm_based_alpha
|
||||
|
||||
# Round to a clean value for better usability
|
||||
alpha = round(alpha, 2)
|
||||
# Round to a clean value for better usability
|
||||
alpha = round(alpha, 2)
|
||||
|
||||
# Ensure alpha is positive and within reasonable bounds
|
||||
alpha = max(0.1, min(alpha, 1024.0))
|
||||
# Ensure alpha is positive and within reasonable bounds
|
||||
alpha = max(0.1, min(alpha, 1024.0))
|
||||
|
||||
return lora_up, lora_down, alpha
|
||||
|
||||
@@ -4,6 +4,41 @@ from library.network_utils import initialize_pissa
|
||||
from library.test_util import generate_synthetic_weights
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_constraints():
    """initialize_pissa runs for a rank below and a rank equal to the smaller weight dimension."""
    in_features, out_features = 20, 10

    org_module = torch.nn.Linear(in_features, out_features)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    # The weight is re-initialised in place right after the synthetic assignment.
    torch.nn.init.xavier_uniform_(org_module.weight)
    torch.nn.init.zeros_(org_module.bias)

    # rank=3 is below min(in, out); rank=10 equals it.
    for rank in (3, 10):
        lora_down = torch.nn.Linear(in_features, rank)
        lora_up = torch.nn.Linear(rank, out_features)
        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank)
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_limits():
    """initialize_pissa accepts both the smallest rank (1) and the largest one, min(in, out)."""
    in_features, out_features = 10, 5
    org_module = torch.nn.Linear(in_features, out_features)

    # First the minimum rank, then the maximum rank allowed by the weight shape.
    for rank in (1, min(in_features, out_features)):
        lora_down = torch.nn.Linear(in_features, rank)
        lora_up = torch.nn.Linear(rank, out_features)
        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank)
|
||||
|
||||
|
||||
def test_initialize_pissa_basic():
|
||||
# Create a simple linear layer
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
@@ -31,23 +66,129 @@ def test_initialize_pissa_basic():
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_constraints():
|
||||
# Test with different rank values
|
||||
def test_initialize_pissa_with_ipca():
    """The IncrementalPCA path modifies the base weight and leaves correctly shaped LoRA weights."""
    in_features, out_features, rank = 100, 50, 8

    # Larger dimensions than the basic tests, to exercise IPCA.
    org_module = torch.nn.Linear(in_features, out_features)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    lora_down = torch.nn.Linear(in_features, rank)
    lora_up = torch.nn.Linear(rank, out_features)

    weight_before = org_module.weight.data.clone()

    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank, use_ipca=True)

    # The base weight must have been changed by the initialization.
    assert not torch.equal(weight_before, org_module.weight.data)

    # LoRA weights keep the (out, in) layout of torch.nn.Linear.
    assert lora_down.weight.shape == torch.Size([rank, in_features])
    assert lora_up.weight.shape == torch.Size([out_features, rank])
|
||||
|
||||
|
||||
def test_initialize_pissa_with_lowrank():
    """The low-rank SVD path modifies the base weight."""
    in_features, out_features, rank = 50, 30, 5

    org_module = torch.nn.Linear(in_features, out_features)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    lora_down = torch.nn.Linear(in_features, rank)
    lora_up = torch.nn.Linear(rank, out_features)

    weight_before = org_module.weight.data.clone()

    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank, use_lowrank=True)

    # The base weight must differ from the snapshot taken before the call.
    assert not torch.equal(weight_before, org_module.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_with_lowrank_seed():
|
||||
# Test reproducibility with seed
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
torch.nn.init.xavier_uniform_(org_module.weight)
|
||||
torch.nn.init.zeros_(org_module.bias)
|
||||
# First run with seed
|
||||
lora_down1 = torch.nn.Linear(20, 3)
|
||||
lora_up1 = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down1, lora_up1, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42)
|
||||
|
||||
# Test with rank less than min dimension
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
result1_down = lora_down1.weight.data.clone()
|
||||
result1_up = lora_up1.weight.data.clone()
|
||||
|
||||
# Test with rank equal to min dimension
|
||||
lora_down = torch.nn.Linear(20, 10)
|
||||
lora_up = torch.nn.Linear(10, 10)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10)
|
||||
# Reset module
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
# Second run with same seed
|
||||
lora_down2 = torch.nn.Linear(20, 3)
|
||||
lora_up2 = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down2, lora_up2, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42)
|
||||
|
||||
# Results should be identical
|
||||
torch.testing.assert_close(result1_down, lora_down2.weight.data)
|
||||
torch.testing.assert_close(result1_up, lora_up2.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_ipca_with_lowrank():
    """IncrementalPCA combined with low-rank SVD leaves correctly shaped LoRA weights."""
    in_features, out_features, rank = 200, 100, 10

    org_module = torch.nn.Linear(in_features, out_features)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    lora_down = torch.nn.Linear(in_features, rank)
    lora_up = torch.nn.Linear(rank, out_features)

    # Both options enabled at once, with an explicit q for the low-rank step.
    initialize_pissa(
        org_module,
        lora_down,
        lora_up,
        scale=0.1,
        rank=rank,
        use_ipca=True,
        use_lowrank=True,
        lowrank_q=20,
    )

    assert lora_down.weight.shape == torch.Size([rank, in_features])
    assert lora_up.weight.shape == torch.Size([out_features, rank])
|
||||
|
||||
|
||||
def test_initialize_pissa_custom_lowrank_params():
    """initialize_pissa works with a custom low-rank q value and iteration count.

    The original checks (``weight.data is not None``) can never fail, so the
    shapes of the LoRA weights are verified as well, in the same way the IPCA
    tests in this file do.
    """
    in_features, out_features, rank = 30, 20, 5

    org_module = torch.nn.Linear(in_features, out_features)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    lora_down = torch.nn.Linear(in_features, rank)
    lora_up = torch.nn.Linear(rank, out_features)

    # Test with custom q value and iterations
    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank, use_lowrank=True, lowrank_q=12, lowrank_niter=6)

    # Check basic validity
    assert lora_down.weight.data is not None
    assert lora_up.weight.data is not None

    # LoRA weights keep the (out, in) layout of torch.nn.Linear.
    assert lora_down.weight.shape == torch.Size([rank, in_features])
    assert lora_up.weight.shape == torch.Size([out_features, rank])
|
||||
|
||||
|
||||
def test_initialize_pissa_device_handling():
    """Modules stay on the device they were created on after initialize_pissa.

    CPU is always tested; CUDA is added only when available. Previously the
    CPU entry was listed twice on hosts without CUDA, so the same scenario ran
    twice.
    """
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))

    for device in devices:
        # Create modules on specific device
        org_module = torch.nn.Linear(10, 5).to(device)
        lora_down = torch.nn.Linear(10, 2).to(device)
        lora_up = torch.nn.Linear(2, 5).to(device)

        # Test initialization with explicit device
        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device)

        # Verify modules are on the correct device
        assert org_module.weight.data.device.type == device.type
        assert lora_down.weight.data.device.type == device.type
        assert lora_up.weight.data.device.type == device.type

        # The IPCA path is only exercised on CPU here, with a small matrix.
        if device.type == "cpu":
            org_module_small = torch.nn.Linear(20, 10).to(device)
            lora_down_small = torch.nn.Linear(20, 3).to(device)
            lora_up_small = torch.nn.Linear(3, 10).to(device)

            initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device, use_ipca=True)

            assert org_module_small.weight.data.device.type == device.type
|
||||
|
||||
|
||||
def test_initialize_pissa_shape_mismatch():
|
||||
@@ -118,25 +259,6 @@ def test_initialize_pissa_svd_properties():
|
||||
torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_initialize_pissa_device_handling():
|
||||
# Test different device scenarios
|
||||
devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")]
|
||||
|
||||
for device in devices:
|
||||
# Create modules on specific device
|
||||
org_module = torch.nn.Linear(10, 5).to(device)
|
||||
lora_down = torch.nn.Linear(10, 2).to(device)
|
||||
lora_up = torch.nn.Linear(2, 5).to(device)
|
||||
|
||||
# Test initialization with explicit device
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device)
|
||||
|
||||
# Verify modules are on the correct device
|
||||
assert org_module.weight.data.device.type == device.type
|
||||
assert lora_down.weight.data.device.type == device.type
|
||||
assert lora_up.weight.data.device.type == device.type
|
||||
|
||||
|
||||
def test_initialize_pissa_dtype_preservation():
|
||||
# Test dtype preservation and conversion
|
||||
dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
@@ -152,59 +274,62 @@ def test_initialize_pissa_dtype_preservation():
|
||||
assert lora_down.weight.dtype == dtype
|
||||
assert lora_up.weight.dtype == dtype
|
||||
|
||||
# Test with explicit dtype
|
||||
if dtype != torch.float16: # Skip float16 for computational stability in SVD
|
||||
org_module2 = torch.nn.Linear(10, 5).to(dtype=torch.float32)
|
||||
lora_down2 = torch.nn.Linear(10, 2).to(dtype=torch.float32)
|
||||
lora_up2 = torch.nn.Linear(2, 5).to(dtype=torch.float32)
|
||||
|
||||
def test_initialize_pissa_rank_limits():
|
||||
# Test rank limits
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype)
|
||||
|
||||
# Test minimum rank (should work)
|
||||
lora_down_min = torch.nn.Linear(10, 1)
|
||||
lora_up_min = torch.nn.Linear(1, 5)
|
||||
initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1)
|
||||
# Original module should be converted to specified dtype
|
||||
assert org_module2.weight.dtype == dtype
|
||||
|
||||
# Test maximum rank (rank = min(input_dim, output_dim))
|
||||
max_rank = min(10, 5)
|
||||
lora_down_max = torch.nn.Linear(10, max_rank)
|
||||
lora_up_max = torch.nn.Linear(max_rank, 5)
|
||||
initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank)
|
||||
|
||||
|
||||
def test_initialize_pissa_numerical_stability():
|
||||
# Test with numerically challenging scenarios
|
||||
scenarios = [
|
||||
torch.randn(20, 10) * 1e-10, # Very small values
|
||||
torch.randn(20, 10) * 1e10, # Very large values
|
||||
torch.zeros(20, 10), # Zero matrix
|
||||
torch.randn(20, 10) * 1e-5, # Small values
|
||||
torch.randn(20, 10) * 1e5, # Large values
|
||||
torch.ones(20, 10), # Uniform values
|
||||
]
|
||||
|
||||
for i, weight_matrix in enumerate(scenarios):
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = weight_matrix
|
||||
|
||||
lora_down = torch.nn.Linear(10, 3)
|
||||
lora_up = torch.nn.Linear(3, 20)
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Test IPCA as well
|
||||
lora_down_ipca = torch.nn.Linear(20, 3)
|
||||
lora_up_ipca = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3, use_ipca=True)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Initialization failed for scenario ({i}): {e}")
|
||||
|
||||
|
||||
def test_initialize_pissa_scale_effects():
|
||||
# Test different scaling factors
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
original_weight = org_module.weight.data.clone()
|
||||
# Test effect of different scaling factors
|
||||
org_module = torch.nn.Linear(15, 10)
|
||||
original_weight = torch.randn_like(org_module.weight.data)
|
||||
org_module.weight.data = original_weight.clone()
|
||||
|
||||
test_scales = [0.0, 0.1, 0.5, 1.0]
|
||||
# Try different scales
|
||||
scales = [0.0, 0.01, 0.1, 1.0]
|
||||
|
||||
for scale in test_scales:
|
||||
# Reset module for each test
|
||||
for scale in scales:
|
||||
# Reset to original weights
|
||||
org_module.weight.data = original_weight.clone()
|
||||
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
lora_down = torch.nn.Linear(15, 4)
|
||||
lora_up = torch.nn.Linear(4, 10)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=4)
|
||||
|
||||
# Verify weight modification proportional to scale
|
||||
weight_diff = original_weight - org_module.weight.data
|
||||
@@ -213,4 +338,84 @@ def test_initialize_pissa_scale_effects():
|
||||
if scale == 0.0:
|
||||
torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6)
|
||||
else:
|
||||
assert not torch.allclose(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6)
|
||||
# For non-zero scales, verify the magnitude of change is proportional to scale
|
||||
assert weight_diff.abs().sum() > 0
|
||||
|
||||
# Do a second run with double the scale
|
||||
org_module2 = torch.nn.Linear(15, 10)
|
||||
org_module2.weight.data = original_weight.clone()
|
||||
|
||||
lora_down2 = torch.nn.Linear(15, 4)
|
||||
lora_up2 = torch.nn.Linear(4, 10)
|
||||
|
||||
initialize_pissa(org_module2, lora_down2, lora_up2, scale=scale * 2, rank=4)
|
||||
|
||||
weight_diff2 = original_weight - org_module2.weight.data
|
||||
|
||||
# The ratio of differences should be approximately 2
|
||||
# (allowing for numerical precision issues)
|
||||
ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10)
|
||||
assert 1.9 < ratio < 2.1
|
||||
|
||||
def test_initialize_pissa_large_matrix_performance():
    """initialize_pissa completes on a 1000x500 layer with plain SVD, IPCA, and IPCA plus low-rank.

    Skipped when CUDA is not available, to avoid long test times on CPU.
    """
    if not torch.cuda.is_available():
        pytest.skip("Skipping large matrix test on CPU")

    in_features, out_features, rank = 1000, 500, 16

    org_module = torch.nn.Linear(in_features, out_features)
    org_module.weight.data = torch.randn_like(org_module.weight.data) * 0.1

    # (label used in the failure message, extra keyword arguments for the call)
    variants = (
        ("Standard SVD", {}),
        ("IPCA approach", {"use_ipca": True}),
        ("Combined IPCA+lowrank approach", {"use_ipca": True, "use_lowrank": True}),
    )

    for label, extra_kwargs in variants:
        # Fresh adapter modules for every variant; the base module is shared.
        lora_down = torch.nn.Linear(in_features, rank)
        lora_up = torch.nn.Linear(rank, out_features)

        try:
            initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=rank, **extra_kwargs)
        except Exception as e:
            pytest.fail(f"{label} failed on large matrix: {e}")
|
||||
|
||||
|
||||
def test_initialize_pissa_requires_grad_preservation():
    """The base weight's requires_grad flag is the same before and after initialize_pissa."""
    in_features, out_features, rank = 20, 10, 4

    frozen_module = torch.nn.Linear(in_features, out_features)
    frozen_module.weight.requires_grad = False

    lora_down = torch.nn.Linear(in_features, rank)
    lora_up = torch.nn.Linear(rank, out_features)

    initialize_pissa(frozen_module, lora_down, lora_up, scale=0.1, rank=rank)

    # A frozen weight must stay frozen.
    assert not frozen_module.weight.requires_grad

    trainable_module = torch.nn.Linear(in_features, out_features)
    trainable_module.weight.requires_grad = True

    # The same adapter modules are reused for the second call.
    initialize_pissa(trainable_module, lora_down, lora_up, scale=0.1, rank=rank)

    # A trainable weight must stay trainable.
    assert trainable_module.weight.requires_grad
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user