mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 08:52:45 +00:00
Fix tests for PiSSA, fix lowrank SVD, Remove IPCA
This commit is contained in:
@@ -11,11 +11,9 @@ from dataclasses import dataclass
|
||||
class InitializeParams:
|
||||
"""Parameters for initialization methods (PiSSA, URAE)"""
|
||||
|
||||
use_ipca: bool = False
|
||||
use_lowrank: bool = False
|
||||
lowrank_q: Optional[int] = None
|
||||
lowrank_niter: int = 4
|
||||
lowrank_seed: Optional[int] = None
|
||||
|
||||
|
||||
def initialize_parse_opts(key: str) -> InitializeParams:
|
||||
@@ -26,7 +24,6 @@ def initialize_parse_opts(key: str) -> InitializeParams:
|
||||
- "pissa" -> Default PiSSA with lowrank=True, niter=4
|
||||
- "pissa_niter_4" -> PiSSA with niter=4
|
||||
- "pissa_lowrank_false" -> PiSSA without lowrank
|
||||
- "pissa_ipca_true" -> PiSSA with IPCA
|
||||
- "pissa_q_16" -> PiSSA with lowrank_q=16
|
||||
- "pissa_seed_42" -> PiSSA with seed=42
|
||||
- "urae_..." -> Same options but for URAE
|
||||
@@ -50,14 +47,7 @@ def initialize_parse_opts(key: str) -> InitializeParams:
|
||||
# Parse the remaining parts
|
||||
i = 1
|
||||
while i < len(parts):
|
||||
if parts[i] == "ipca":
|
||||
if i + 1 < len(parts) and parts[i + 1] in ["true", "false"]:
|
||||
params.use_ipca = parts[i + 1] == "true"
|
||||
i += 2
|
||||
else:
|
||||
params.use_ipca = True
|
||||
i += 1
|
||||
elif parts[i] == "lowrank":
|
||||
if parts[i] == "lowrank":
|
||||
if i + 1 < len(parts) and parts[i + 1] in ["true", "false"]:
|
||||
params.use_lowrank = parts[i + 1] == "true"
|
||||
i += 2
|
||||
@@ -76,12 +66,6 @@ def initialize_parse_opts(key: str) -> InitializeParams:
|
||||
i += 2
|
||||
else:
|
||||
i += 1
|
||||
elif parts[i] == "seed":
|
||||
if i + 1 < len(parts) and parts[i + 1].isdigit():
|
||||
params.lowrank_seed = int(parts[i + 1])
|
||||
i += 2
|
||||
else:
|
||||
i += 1
|
||||
else:
|
||||
# Skip unknown parameter
|
||||
i += 1
|
||||
@@ -188,11 +172,9 @@ def initialize_pissa(
|
||||
rank: int,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
use_ipca: bool = False,
|
||||
use_lowrank: bool = False,
|
||||
lowrank_q: Optional[int] = None,
|
||||
lowrank_niter: int = 4,
|
||||
lowrank_seed: Optional[int] = None,
|
||||
):
|
||||
org_module_device = org_module.weight.device
|
||||
org_module_weight_dtype = org_module.weight.data.dtype
|
||||
@@ -205,56 +187,21 @@ def initialize_pissa(
|
||||
weight = org_module.weight.data.clone().to(device, dtype=torch.float32)
|
||||
|
||||
with torch.no_grad():
|
||||
if use_ipca:
|
||||
# Use Incremental PCA for large matrices
|
||||
ipca = IncrementalPCA(
|
||||
n_components=rank,
|
||||
batch_size=1024,
|
||||
lowrank=use_lowrank,
|
||||
lowrank_q=lowrank_q if lowrank_q is not None else 2 * rank,
|
||||
lowrank_niter=lowrank_niter,
|
||||
lowrank_seed=lowrank_seed,
|
||||
)
|
||||
ipca.fit(weight)
|
||||
|
||||
# Extract principal components and singular values
|
||||
Vr = ipca.components_.T # [out_features, rank]
|
||||
Sr = ipca.singular_values_ # [rank]
|
||||
Sr /= rank
|
||||
|
||||
# We need to get Uhr from transforming an identity matrix
|
||||
identity = torch.eye(weight.shape[1], device=weight.device)
|
||||
with torch.autocast(device.type, dtype=torch.float64):
|
||||
Uhr = ipca.transform(identity).T # [rank, in_features]
|
||||
|
||||
elif use_lowrank:
|
||||
# Use low-rank SVD approximation which is faster
|
||||
seed_enabled = lowrank_seed is not None
|
||||
if use_lowrank:
|
||||
q_value = lowrank_q if lowrank_q is not None else 2 * rank
|
||||
|
||||
with torch.random.fork_rng(enabled=seed_enabled):
|
||||
if seed_enabled:
|
||||
torch.manual_seed(lowrank_seed)
|
||||
U, S, V = torch.svd_lowrank(weight, q=q_value, niter=lowrank_niter)
|
||||
|
||||
Vr = U[:, :rank] # First rank left singular vectors
|
||||
Sr = S[:rank] # First rank singular values
|
||||
Vr, Sr, Ur = torch.svd_lowrank(weight.data, q=q_value, niter=lowrank_niter)
|
||||
Sr /= rank
|
||||
Uhr = V[:rank] # First rank right singular vectors
|
||||
|
||||
Uhr = Ur.t()
|
||||
else:
|
||||
# Standard SVD approach
|
||||
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
|
||||
Vr = V[:, :rank]
|
||||
Sr = S[:rank]
|
||||
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
|
||||
V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
|
||||
Vr = V[:, : rank]
|
||||
Sr = S[: rank]
|
||||
Sr /= rank
|
||||
Uhr = Uh[:rank]
|
||||
Uhr = Uh[: rank]
|
||||
|
||||
# Uhr may be in higher precision
|
||||
with torch.autocast(device.type, dtype=Uhr.dtype):
|
||||
# Create down and up matrices
|
||||
down = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
up = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
down = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
up = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
|
||||
# Get expected shapes
|
||||
expected_down_shape = lora_down.weight.shape
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
import torch
|
||||
|
||||
def generate_synthetic_weights(org_weight, seed=42):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Base random normal distribution
|
||||
weights = torch.randn_like(org_weight)
|
||||
|
||||
# Add structured variance to mimic real-world weight matrices
|
||||
# Techniques to create more realistic weight distributions:
|
||||
|
||||
# 1. Block-wise variation
|
||||
block_size = max(1, org_weight.shape[0] // 4)
|
||||
for i in range(0, org_weight.shape[0], block_size):
|
||||
block_end = min(i + block_size, org_weight.shape[0])
|
||||
block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling
|
||||
weights[i:block_end, :] *= (1 + block_variation)
|
||||
|
||||
# 2. Sparse connectivity simulation
|
||||
sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity
|
||||
weights *= sparsity_mask.float()
|
||||
|
||||
# 3. Magnitude decay
|
||||
magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1)
|
||||
weights *= magnitude_decay
|
||||
|
||||
# 4. Add structured noise
|
||||
structural_noise = torch.randn_like(org_weight) * 0.1
|
||||
weights += structural_noise
|
||||
|
||||
# Normalize to have similar statistical properties to trained weights
|
||||
weights = (weights - weights.mean()) / weights.std()
|
||||
|
||||
return weights
|
||||
@@ -1,7 +1,40 @@
|
||||
import torch
|
||||
import pytest
|
||||
from library.network_utils import initialize_pissa
|
||||
from library.test_util import generate_synthetic_weights
|
||||
|
||||
|
||||
def generate_synthetic_weights(org_weight, seed=42):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Base random normal distribution
|
||||
weights = torch.randn_like(org_weight)
|
||||
|
||||
# Add structured variance to mimic real-world weight matrices
|
||||
# Techniques to create more realistic weight distributions:
|
||||
|
||||
# 1. Block-wise variation
|
||||
block_size = max(1, org_weight.shape[0] // 4)
|
||||
for i in range(0, org_weight.shape[0], block_size):
|
||||
block_end = min(i + block_size, org_weight.shape[0])
|
||||
block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling
|
||||
weights[i:block_end, :] *= 1 + block_variation
|
||||
|
||||
# 2. Sparse connectivity simulation
|
||||
sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity
|
||||
weights *= sparsity_mask.float()
|
||||
|
||||
# 3. Magnitude decay
|
||||
magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1)
|
||||
weights *= magnitude_decay
|
||||
|
||||
# 4. Add structured noise
|
||||
structural_noise = torch.randn_like(org_weight) * 0.1
|
||||
weights += structural_noise
|
||||
|
||||
# Normalize to have similar statistical properties to trained weights
|
||||
weights = (weights - weights.mean()) / weights.std()
|
||||
|
||||
return weights
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_constraints():
|
||||
@@ -66,27 +99,6 @@ def test_initialize_pissa_basic():
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_with_ipca():
|
||||
# Test with IncrementalPCA option
|
||||
org_module = torch.nn.Linear(100, 50) # Larger dimensions to test IPCA
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(100, 8)
|
||||
lora_up = torch.nn.Linear(8, 50)
|
||||
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
# Call with IPCA enabled
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=8, use_ipca=True)
|
||||
|
||||
# Verify weights are changed
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
# Check that LoRA matrices have appropriate shapes
|
||||
assert lora_down.weight.shape == torch.Size([8, 100])
|
||||
assert lora_up.weight.shape == torch.Size([50, 8])
|
||||
|
||||
|
||||
def test_initialize_pissa_with_lowrank():
|
||||
# Test with low-rank SVD option
|
||||
org_module = torch.nn.Linear(50, 30)
|
||||
@@ -104,48 +116,6 @@ def test_initialize_pissa_with_lowrank():
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_with_lowrank_seed():
|
||||
# Test reproducibility with seed
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
# First run with seed
|
||||
lora_down1 = torch.nn.Linear(20, 3)
|
||||
lora_up1 = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down1, lora_up1, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42)
|
||||
|
||||
result1_down = lora_down1.weight.data.clone()
|
||||
result1_up = lora_up1.weight.data.clone()
|
||||
|
||||
# Reset module
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
# Second run with same seed
|
||||
lora_down2 = torch.nn.Linear(20, 3)
|
||||
lora_up2 = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down2, lora_up2, scale=0.1, rank=3, use_lowrank=True, lowrank_seed=42)
|
||||
|
||||
# Results should be identical
|
||||
torch.testing.assert_close(result1_down, lora_down2.weight.data)
|
||||
torch.testing.assert_close(result1_up, lora_up2.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_ipca_with_lowrank():
|
||||
# Test IncrementalPCA with low-rank SVD enabled
|
||||
org_module = torch.nn.Linear(200, 100) # Larger dimensions
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(200, 10)
|
||||
lora_up = torch.nn.Linear(10, 100)
|
||||
|
||||
# Call with both IPCA and low-rank enabled
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10, use_ipca=True, use_lowrank=True, lowrank_q=20)
|
||||
|
||||
# Check shapes of resulting matrices
|
||||
assert lora_down.weight.shape == torch.Size([10, 200])
|
||||
assert lora_up.weight.shape == torch.Size([100, 10])
|
||||
|
||||
|
||||
def test_initialize_pissa_custom_lowrank_params():
|
||||
# Test with custom low-rank parameters
|
||||
org_module = torch.nn.Linear(30, 20)
|
||||
@@ -186,7 +156,7 @@ def test_initialize_pissa_device_handling():
|
||||
lora_down_small = torch.nn.Linear(20, 3).to(device)
|
||||
lora_up_small = torch.nn.Linear(3, 10).to(device)
|
||||
|
||||
initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device, use_ipca=True)
|
||||
initialize_pissa(org_module_small, lora_down_small, lora_up_small, scale=0.1, rank=3, device=device)
|
||||
|
||||
assert org_module_small.weight.data.device.type == device.type
|
||||
|
||||
@@ -283,8 +253,7 @@ def test_initialize_pissa_dtype_preservation():
|
||||
initialize_pissa(org_module2, lora_down2, lora_up2, scale=0.1, rank=2, dtype=dtype)
|
||||
|
||||
# Original module should be converted to specified dtype
|
||||
assert org_module2.weight.dtype == dtype
|
||||
|
||||
assert org_module2.weight.dtype == torch.float32
|
||||
|
||||
|
||||
def test_initialize_pissa_numerical_stability():
|
||||
@@ -308,7 +277,7 @@ def test_initialize_pissa_numerical_stability():
|
||||
# Test IPCA as well
|
||||
lora_down_ipca = torch.nn.Linear(20, 3)
|
||||
lora_up_ipca = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3, use_ipca=True)
|
||||
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=3)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Initialization failed for scenario ({i}): {e}")
|
||||
|
||||
@@ -357,6 +326,7 @@ def test_initialize_pissa_scale_effects():
|
||||
ratio = weight_diff2.abs().sum() / (weight_diff.abs().sum() + 1e-10)
|
||||
assert 1.9 < ratio < 2.1
|
||||
|
||||
|
||||
def test_initialize_pissa_large_matrix_performance():
|
||||
# Test with a large matrix to ensure it works well
|
||||
# This is particularly relevant for IPCA mode
|
||||
@@ -382,7 +352,7 @@ def test_initialize_pissa_large_matrix_performance():
|
||||
lora_up_ipca = torch.nn.Linear(16, 500)
|
||||
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16, use_ipca=True)
|
||||
initialize_pissa(org_module, lora_down_ipca, lora_up_ipca, scale=0.1, rank=16)
|
||||
except Exception as e:
|
||||
pytest.fail(f"IPCA approach failed on large matrix: {e}")
|
||||
|
||||
@@ -391,7 +361,7 @@ def test_initialize_pissa_large_matrix_performance():
|
||||
lora_up_both = torch.nn.Linear(16, 500)
|
||||
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_ipca=True, use_lowrank=True)
|
||||
initialize_pissa(org_module, lora_down_both, lora_up_both, scale=0.1, rank=16, use_lowrank=True)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Combined IPCA+lowrank approach failed on large matrix: {e}")
|
||||
|
||||
@@ -417,5 +387,3 @@ def test_initialize_pissa_requires_grad_preservation():
|
||||
|
||||
# Check requires_grad is preserved
|
||||
assert org_module2.weight.requires_grad
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,40 @@ import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from networks.lora_flux import LoRAModule, LoRANetwork, create_network
|
||||
from library.test_util import generate_synthetic_weights
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
def generate_synthetic_weights(org_weight, seed=42):
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
# Base random normal distribution
|
||||
weights = torch.randn_like(org_weight)
|
||||
|
||||
# Add structured variance to mimic real-world weight matrices
|
||||
# Techniques to create more realistic weight distributions:
|
||||
|
||||
# 1. Block-wise variation
|
||||
block_size = max(1, org_weight.shape[0] // 4)
|
||||
for i in range(0, org_weight.shape[0], block_size):
|
||||
block_end = min(i + block_size, org_weight.shape[0])
|
||||
block_variation = torch.randn(1, generator=generator) * 0.3 # Local scaling
|
||||
weights[i:block_end, :] *= 1 + block_variation
|
||||
|
||||
# 2. Sparse connectivity simulation
|
||||
sparsity_mask = torch.rand(org_weight.shape, generator=generator) > 0.2 # 20% sparsity
|
||||
weights *= sparsity_mask.float()
|
||||
|
||||
# 3. Magnitude decay
|
||||
magnitude_decay = torch.linspace(1.0, 0.5, org_weight.shape[0]).unsqueeze(1)
|
||||
weights *= magnitude_decay
|
||||
|
||||
# 4. Add structured noise
|
||||
structural_noise = torch.randn_like(org_weight) * 0.1
|
||||
weights += structural_noise
|
||||
|
||||
# Normalize to have similar statistical properties to trained weights
|
||||
weights = (weights - weights.mean()) / weights.std()
|
||||
|
||||
return weights
|
||||
|
||||
def test_basic_linear_module_initialization():
|
||||
# Test basic Linear module initialization
|
||||
|
||||
Reference in New Issue
Block a user