mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 09:18:00 +00:00
Update initialization, add lora_util, add tests
This commit is contained in:
216
tests/library/test_lora_util.py
Normal file
216
tests/library/test_lora_util.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import torch
|
||||
import pytest
|
||||
from library.lora_util import initialize_pissa
|
||||
from tests.test_util import generate_synthetic_weights
|
||||
|
||||
|
||||
def test_initialize_pissa_basic():
|
||||
# Create a simple linear layer
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
torch.nn.init.xavier_uniform_(org_module.weight)
|
||||
torch.nn.init.zeros_(org_module.bias)
|
||||
|
||||
# Create LoRA layers with matching shapes
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
# Store original weight for comparison
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
# Call the initialization function
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
# Verify basic properties
|
||||
assert lora_down.weight.data is not None
|
||||
assert lora_up.weight.data is not None
|
||||
assert org_module.weight.data is not None
|
||||
|
||||
# Check that the weights have been modified
|
||||
assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_constraints():
|
||||
# Test with different rank values
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
torch.nn.init.xavier_uniform_(org_module.weight)
|
||||
torch.nn.init.zeros_(org_module.bias)
|
||||
|
||||
# Test with rank less than min dimension
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Test with rank equal to min dimension
|
||||
lora_down = torch.nn.Linear(20, 10)
|
||||
lora_up = torch.nn.Linear(10, 10)
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10)
|
||||
|
||||
|
||||
def test_initialize_pissa_shape_mismatch():
|
||||
# Test with shape mismatch to ensure warning is printed
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
|
||||
# Intentionally mismatched shapes to test warning mechanism
|
||||
lora_down = torch.nn.Linear(20, 5) # Different shape
|
||||
lora_up = torch.nn.Linear(3, 15) # Different shape
|
||||
|
||||
with pytest.warns(UserWarning):
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
|
||||
def test_initialize_pissa_scaling():
|
||||
# Test different scaling factors
|
||||
scales = [0.0, 0.1, 1.0]
|
||||
|
||||
for scale in scales:
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2)
|
||||
|
||||
# Check that the weight modification follows the scaling
|
||||
weight_diff = original_weight - org_module.weight.data
|
||||
expected_diff = scale * (lora_up.weight.data @ lora_down.weight.data)
|
||||
|
||||
torch.testing.assert_close(weight_diff, expected_diff, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_initialize_pissa_dtype():
|
||||
# Test with different data types
|
||||
dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
for dtype in dtypes:
|
||||
org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
# Verify output dtype matches input
|
||||
assert org_module.weight.dtype == dtype
|
||||
|
||||
|
||||
def test_initialize_pissa_svd_properties():
|
||||
# Verify SVD decomposition properties
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
lora_down = torch.nn.Linear(20, 3)
|
||||
lora_up = torch.nn.Linear(3, 10)
|
||||
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
|
||||
# Reconstruct the weight
|
||||
reconstructed_weight = original_weight - 0.1 * (lora_up.weight.data @ lora_down.weight.data)
|
||||
|
||||
# Check reconstruction is close to original
|
||||
torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_initialize_pissa_device_handling():
|
||||
# Test different device scenarios
|
||||
devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")]
|
||||
|
||||
for device in devices:
|
||||
# Create modules on specific device
|
||||
org_module = torch.nn.Linear(10, 5).to(device)
|
||||
lora_down = torch.nn.Linear(10, 2).to(device)
|
||||
lora_up = torch.nn.Linear(2, 5).to(device)
|
||||
|
||||
# Test initialization with explicit device
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device)
|
||||
|
||||
# Verify modules are on the correct device
|
||||
assert org_module.weight.data.device.type == device.type
|
||||
assert lora_down.weight.data.device.type == device.type
|
||||
assert lora_up.weight.data.device.type == device.type
|
||||
|
||||
|
||||
def test_initialize_pissa_dtype_preservation():
|
||||
# Test dtype preservation and conversion
|
||||
dtypes = [torch.float16, torch.float32, torch.float64]
|
||||
|
||||
for dtype in dtypes:
|
||||
org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
|
||||
lora_down = torch.nn.Linear(10, 2).to(dtype=dtype)
|
||||
lora_up = torch.nn.Linear(2, 5).to(dtype=dtype)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
|
||||
|
||||
assert org_module.weight.dtype == dtype
|
||||
assert lora_down.weight.dtype == dtype
|
||||
assert lora_up.weight.dtype == dtype
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_limits():
|
||||
# Test rank limits
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
|
||||
# Test minimum rank (should work)
|
||||
lora_down_min = torch.nn.Linear(10, 1)
|
||||
lora_up_min = torch.nn.Linear(1, 5)
|
||||
initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1)
|
||||
|
||||
# Test maximum rank (rank = min(input_dim, output_dim))
|
||||
max_rank = min(10, 5)
|
||||
lora_down_max = torch.nn.Linear(10, max_rank)
|
||||
lora_up_max = torch.nn.Linear(max_rank, 5)
|
||||
initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank)
|
||||
|
||||
|
||||
def test_initialize_pissa_numerical_stability():
|
||||
# Test with numerically challenging scenarios
|
||||
scenarios = [
|
||||
torch.randn(20, 10) * 1e-10, # Very small values
|
||||
torch.randn(20, 10) * 1e10, # Very large values
|
||||
torch.zeros(20, 10), # Zero matrix
|
||||
]
|
||||
|
||||
for i, weight_matrix in enumerate(scenarios):
|
||||
org_module = torch.nn.Linear(20, 10)
|
||||
org_module.weight.data = weight_matrix
|
||||
|
||||
lora_down = torch.nn.Linear(10, 3)
|
||||
lora_up = torch.nn.Linear(3, 20)
|
||||
|
||||
try:
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Initialization failed for scenario ({i}): {e}")
|
||||
|
||||
|
||||
def test_initialize_pissa_scale_effects():
|
||||
# Test different scaling factors
|
||||
org_module = torch.nn.Linear(10, 5)
|
||||
original_weight = org_module.weight.data.clone()
|
||||
|
||||
test_scales = [0.0, 0.1, 0.5, 1.0]
|
||||
|
||||
for scale in test_scales:
|
||||
# Reset module for each test
|
||||
org_module.weight.data = original_weight.clone()
|
||||
|
||||
lora_down = torch.nn.Linear(10, 2)
|
||||
lora_up = torch.nn.Linear(2, 5)
|
||||
|
||||
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2)
|
||||
|
||||
# Verify weight modification proportional to scale
|
||||
weight_diff = original_weight - org_module.weight.data
|
||||
|
||||
# Approximate check of scaling effect
|
||||
if scale == 0.0:
|
||||
torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6)
|
||||
else:
|
||||
assert not torch.allclose(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6)
|
||||
609
tests/networks/test_lora_flux.py
Normal file
609
tests/networks/test_lora_flux.py
Normal file
@@ -0,0 +1,609 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from networks.lora_flux import LoRAModule, LoRANetwork, create_network
|
||||
from tests.test_util import generate_synthetic_weights
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def test_basic_linear_module_initialization():
|
||||
# Test basic Linear module initialization
|
||||
org_module = nn.Linear(10, 20)
|
||||
lora_module = LoRAModule(lora_name="test_linear", org_module=org_module, lora_dim=4)
|
||||
|
||||
# Check basic attributes
|
||||
assert lora_module.lora_name == "test_linear"
|
||||
assert lora_module.lora_dim == 4
|
||||
|
||||
# Check LoRA layers
|
||||
assert isinstance(lora_module.lora_down, nn.Linear)
|
||||
assert isinstance(lora_module.lora_up, nn.Linear)
|
||||
|
||||
# Check input and output dimensions
|
||||
assert lora_module.lora_down.in_features == 10
|
||||
assert lora_module.lora_down.out_features == 4
|
||||
assert lora_module.lora_up.in_features == 4
|
||||
assert lora_module.lora_up.out_features == 20
|
||||
|
||||
|
||||
def test_split_dims_initialization():
|
||||
# Test initialization with split_dims
|
||||
org_module = nn.Linear(10, 15)
|
||||
lora_module = LoRAModule(lora_name="test_split_dims", org_module=org_module, lora_dim=4, split_dims=[5, 5, 5])
|
||||
|
||||
# Check split_dims specific attributes
|
||||
assert lora_module.split_dims == [5, 5, 5]
|
||||
assert isinstance(lora_module.lora_down, nn.ModuleList)
|
||||
assert isinstance(lora_module.lora_up, nn.ModuleList)
|
||||
|
||||
# Check number of split modules
|
||||
assert len(lora_module.lora_down) == 3
|
||||
assert len(lora_module.lora_up) == 3
|
||||
|
||||
# Check dimensions of split modules
|
||||
for down, up in zip(lora_module.lora_down, lora_module.lora_up):
|
||||
assert down.in_features == 10
|
||||
assert down.out_features == 4
|
||||
assert up.in_features == 4
|
||||
assert up.out_features in [5, 5, 5]
|
||||
|
||||
|
||||
def test_alpha_scaling():
|
||||
# Test alpha scaling
|
||||
org_module = nn.Linear(10, 20)
|
||||
|
||||
# Default alpha (should be equal to lora_dim)
|
||||
lora_module1 = LoRAModule(lora_name="test_alpha1", org_module=org_module, lora_dim=4, alpha=0)
|
||||
assert lora_module1.scale == 1.0
|
||||
|
||||
# Custom alpha
|
||||
lora_module2 = LoRAModule(lora_name="test_alpha2", org_module=org_module, lora_dim=4, alpha=2)
|
||||
assert lora_module2.scale == 0.5
|
||||
|
||||
|
||||
def test_initialization_methods():
|
||||
# Test different initialization methods
|
||||
org_module = nn.Linear(10, 20)
|
||||
org_module.weight.data = generate_synthetic_weights(org_module.weight)
|
||||
|
||||
# Default initialization
|
||||
lora_module1 = LoRAModule(lora_name="test_init_default", org_module=org_module, lora_dim=4)
|
||||
|
||||
assert lora_module1.lora_down.weight.shape == (4, 10)
|
||||
assert lora_module1.lora_up.weight.shape == (20, 4)
|
||||
|
||||
# URAE initialization
|
||||
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4, initialize="urae")
|
||||
assert hasattr(lora_module2, "_org_lora_up") and lora_module2._org_lora_down is not None
|
||||
assert hasattr(lora_module2, "_org_lora_down") and lora_module2._org_lora_down is not None
|
||||
|
||||
assert lora_module2.lora_down.weight.shape == (4, 10)
|
||||
assert lora_module2.lora_up.weight.shape == (20, 4)
|
||||
|
||||
# PISSA initialization
|
||||
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4, initialize="pissa")
|
||||
assert hasattr(lora_module3, "_org_lora_up") and lora_module3._org_lora_down is not None
|
||||
assert hasattr(lora_module3, "_org_lora_down") and lora_module3._org_lora_down is not None
|
||||
|
||||
assert lora_module3.lora_down.weight.shape == (4, 10)
|
||||
assert lora_module3.lora_up.weight.shape == (20, 4)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_forward_basic_linear():
|
||||
# Create a basic linear module
|
||||
org_module = nn.Linear(10, 20)
|
||||
org_module.weight.data = torch.testing.make_tensor(
|
||||
org_module.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
|
||||
)
|
||||
|
||||
lora_module = LoRAModule(lora_name="test_forward", org_module=org_module, lora_dim=4, alpha=4, multiplier=1.0)
|
||||
lora_module.apply_to()
|
||||
|
||||
assert isinstance(lora_module.lora_down, nn.Linear)
|
||||
assert isinstance(lora_module.lora_up, nn.Linear)
|
||||
|
||||
lora_module.lora_down.weight.data = torch.testing.make_tensor(
|
||||
lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
|
||||
)
|
||||
lora_module.lora_up.weight.data = torch.testing.make_tensor(
|
||||
lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
|
||||
)
|
||||
|
||||
# Create input
|
||||
x = torch.ones(5, 10)
|
||||
|
||||
# Perform forward pass
|
||||
output = lora_module.forward(x)
|
||||
|
||||
# Structural assertions
|
||||
assert output is not None, "Output should not be None"
|
||||
assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor"
|
||||
|
||||
# Shape assertions
|
||||
assert output.shape == (5, 20), "Output shape should match expected dimensions"
|
||||
|
||||
# Type and device assertions
|
||||
assert output.dtype == torch.float32, "Output should be float32"
|
||||
assert output.device == x.device, "Output should be on the same device as input"
|
||||
|
||||
|
||||
def test_forward_module_dropout():
|
||||
# Create a basic linear module
|
||||
org_module = nn.Linear(10, 20)
|
||||
|
||||
lora_module = LoRAModule(
|
||||
lora_name="test_module_dropout",
|
||||
org_module=org_module,
|
||||
lora_dim=4,
|
||||
multiplier=1.0,
|
||||
module_dropout=1.0, # Always drop
|
||||
)
|
||||
|
||||
lora_module.apply_to()
|
||||
|
||||
# Create input
|
||||
x = torch.ones(5, 10)
|
||||
|
||||
# Enable training mode
|
||||
lora_module.train()
|
||||
|
||||
# Perform forward pass
|
||||
output = lora_module.forward(x)
|
||||
|
||||
# Check if output is same as original module output
|
||||
org_output = org_module(x)
|
||||
torch.testing.assert_close(output, org_output)
|
||||
|
||||
|
||||
def test_forward_rank_dropout():
|
||||
# Create a basic linear module
|
||||
org_module = nn.Linear(10, 20)
|
||||
|
||||
lora_module = LoRAModule(
|
||||
lora_name="test_rank_dropout",
|
||||
org_module=org_module,
|
||||
lora_dim=4,
|
||||
multiplier=1.0,
|
||||
rank_dropout=0.5, # 50% dropout
|
||||
)
|
||||
|
||||
lora_module.apply_to()
|
||||
|
||||
assert isinstance(lora_module.lora_down, nn.Linear)
|
||||
assert isinstance(lora_module.lora_up, nn.Linear)
|
||||
|
||||
# Make lora weights predictable
|
||||
lora_module.lora_down.weight.data = torch.testing.make_tensor(
|
||||
lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
|
||||
)
|
||||
lora_module.lora_up.weight.data = torch.testing.make_tensor(
|
||||
lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
|
||||
)
|
||||
|
||||
# Create input
|
||||
x = torch.ones(5, 10)
|
||||
|
||||
# Enable training mode
|
||||
lora_module.train()
|
||||
|
||||
# Perform multiple forward passes to show dropout effect
|
||||
outputs = [lora_module.forward(x) for _ in range(10)]
|
||||
|
||||
# Check that outputs are not all identical due to rank dropout
|
||||
differences = [
|
||||
torch.all(torch.eq(outputs[i], outputs[j])).item() for i in range(len(outputs)) for j in range(i + 1, len(outputs))
|
||||
]
|
||||
assert not all(differences)
|
||||
|
||||
|
||||
def test_forward_split_dims():
|
||||
# Create a basic linear module with split dimensions
|
||||
org_module = nn.Linear(10, 15)
|
||||
|
||||
lora_module = LoRAModule(lora_name="test_split_dims", org_module=org_module, lora_dim=4, multiplier=1.0, split_dims=[5, 5, 5])
|
||||
|
||||
lora_module.apply_to()
|
||||
|
||||
assert isinstance(lora_module.lora_down, nn.ModuleList)
|
||||
assert isinstance(lora_module.lora_up, nn.ModuleList)
|
||||
|
||||
# Make lora weights predictable
|
||||
for down in lora_module.lora_down:
|
||||
assert isinstance(down, nn.Linear)
|
||||
down.weight.data = torch.testing.make_tensor(down.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0)
|
||||
for up in lora_module.lora_up:
|
||||
assert isinstance(up, nn.Linear)
|
||||
up.weight.data = torch.testing.make_tensor(up.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0)
|
||||
|
||||
# Create input
|
||||
x = torch.ones(5, 10)
|
||||
|
||||
# Perform forward pass
|
||||
output = lora_module.forward(x)
|
||||
|
||||
# Check output dimensions
|
||||
assert output.shape == (5, 15)
|
||||
|
||||
|
||||
def test_forward_dropout():
|
||||
# Create a basic linear module
|
||||
org_module = nn.Linear(10, 20)
|
||||
|
||||
lora_module = LoRAModule(
|
||||
lora_name="test_dropout",
|
||||
org_module=org_module,
|
||||
lora_dim=4,
|
||||
multiplier=1.0,
|
||||
dropout=0.5, # 50% dropout
|
||||
)
|
||||
|
||||
lora_module.apply_to()
|
||||
|
||||
assert isinstance(lora_module.lora_down, nn.Linear)
|
||||
assert isinstance(lora_module.lora_up, nn.Linear)
|
||||
|
||||
# Make lora weights predictable
|
||||
lora_module.lora_down.weight.data = torch.testing.make_tensor(
|
||||
lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
|
||||
)
|
||||
lora_module.lora_up.weight.data = torch.testing.make_tensor(
|
||||
lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
|
||||
)
|
||||
|
||||
# Create input
|
||||
x = torch.ones(5, 10)
|
||||
|
||||
# Enable training mode
|
||||
lora_module.train()
|
||||
|
||||
# Perform multiple forward passes to show dropout effect
|
||||
outputs = [lora_module.forward(x) for _ in range(10)]
|
||||
|
||||
# Check that outputs are not all identical due to dropout
|
||||
differences = [
|
||||
torch.all(torch.eq(outputs[i], outputs[j])).item() for i in range(len(outputs)) for j in range(i + 1, len(outputs))
|
||||
]
|
||||
assert not all(differences)
|
||||
|
||||
|
||||
def test_create_network_default_parameters(mock_text_encoder, mock_flux):
|
||||
# Mock dependencies
|
||||
mock_ae = MagicMock()
|
||||
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
|
||||
|
||||
# Call the function with minimal parameters
|
||||
network = create_network(
|
||||
multiplier=1.0, network_dim=None, network_alpha=None, ae=mock_ae, text_encoders=mock_text_encoders, flux=mock_flux
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert network is not None
|
||||
assert network.multiplier == 1.0
|
||||
assert network.lora_dim == 4 # default network_dim
|
||||
assert network.alpha == 1.0 # default network_alpha
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_text_encoder():
|
||||
class CLIPAttention(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Add some dummy layers to simulate a CLIPAttention
|
||||
self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)])
|
||||
|
||||
class MockTextEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Add some dummy layers to simulate a CLIPTextModel
|
||||
self.attns = torch.nn.ModuleList([CLIPAttention() for _ in range(3)])
|
||||
|
||||
return MockTextEncoder()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flux():
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Add some dummy layers to simulate a DoubleStreamBlock
|
||||
self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)])
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Add some dummy layers to simulate a SingleStreamBlock
|
||||
self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)])
|
||||
|
||||
class MockFlux(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Add some dummy layers to simulate a Flux
|
||||
self.double_blocks = torch.nn.ModuleList([DoubleStreamBlock() for _ in range(3)])
|
||||
self.single_blocks = torch.nn.ModuleList([SingleStreamBlock() for _ in range(3)])
|
||||
|
||||
return MockFlux()
|
||||
|
||||
|
||||
def test_create_network_custom_parameters(mock_text_encoder, mock_flux):
|
||||
# Mock dependencies
|
||||
mock_ae = MagicMock()
|
||||
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
|
||||
|
||||
# Prepare custom parameters
|
||||
custom_params = {
|
||||
"conv_dim": 8,
|
||||
"conv_alpha": 0.5,
|
||||
"img_attn_dim": 16,
|
||||
"txt_attn_dim": 16,
|
||||
"neuron_dropout": 0.1,
|
||||
"rank_dropout": 0.2,
|
||||
"module_dropout": 0.3,
|
||||
"train_blocks": "double",
|
||||
"split_qkv": "True",
|
||||
"train_t5xxl": "True",
|
||||
"in_dims": "[64, 32, 16, 8, 4]",
|
||||
"verbose": "True",
|
||||
}
|
||||
|
||||
# Call the function with custom parameters
|
||||
network = create_network(
|
||||
multiplier=1.5,
|
||||
network_dim=8,
|
||||
network_alpha=2.0,
|
||||
ae=mock_ae,
|
||||
text_encoders=mock_text_encoders,
|
||||
flux=mock_flux,
|
||||
**custom_params,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
assert network is not None
|
||||
assert network.multiplier == 1.5
|
||||
assert network.lora_dim == 8
|
||||
assert network.alpha == 2.0
|
||||
assert network.conv_lora_dim == 8
|
||||
assert network.conv_alpha == 0.5
|
||||
assert network.train_blocks == "double"
|
||||
assert network.split_qkv is True
|
||||
assert network.train_t5xxl is True
|
||||
|
||||
|
||||
def test_create_network_block_indices(mock_text_encoder, mock_flux):
|
||||
# Mock dependencies
|
||||
mock_ae = MagicMock()
|
||||
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
|
||||
|
||||
# Test block indices parsing
|
||||
network = create_network(
|
||||
multiplier=1.0,
|
||||
network_dim=4,
|
||||
network_alpha=1.0,
|
||||
ae=mock_ae,
|
||||
text_encoders=mock_text_encoders,
|
||||
flux=mock_flux,
|
||||
neuron_dropout=None,
|
||||
**{"train_double_block_indices": "0-2,4", "train_single_block_indices": "1,3"},
|
||||
)
|
||||
|
||||
# Assertions would depend on the exact implementation of parsing
|
||||
assert network.train_double_block_indices is not None
|
||||
assert network.train_single_block_indices is not None
|
||||
|
||||
double_block_indices = [
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
]
|
||||
single_block_indices = [
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
]
|
||||
|
||||
assert network.train_double_block_indices == double_block_indices
|
||||
assert network.train_single_block_indices == single_block_indices
|
||||
|
||||
|
||||
def test_create_network_loraplus_ratios(mock_text_encoder, mock_flux):
|
||||
# Mock dependencies
|
||||
mock_ae = MagicMock()
|
||||
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
|
||||
|
||||
# Test LoRA+ ratios
|
||||
network = create_network(
|
||||
multiplier=1.0,
|
||||
network_dim=4,
|
||||
network_alpha=1.0,
|
||||
ae=mock_ae,
|
||||
text_encoders=mock_text_encoders,
|
||||
flux=mock_flux,
|
||||
neuron_dropout=None,
|
||||
**{"loraplus_lr_ratio": 2.0, "loraplus_unet_lr_ratio": 1.5, "loraplus_text_encoder_lr_ratio": 1.0},
|
||||
)
|
||||
|
||||
# Verify LoRA+ ratios were set correctly
|
||||
assert network.loraplus_lr_ratio == 2.0
|
||||
assert network.loraplus_unet_lr_ratio == 1.5
|
||||
assert network.loraplus_text_encoder_lr_ratio == 1.0
|
||||
|
||||
|
||||
def test_create_network_loraplus_default_ratio(mock_text_encoder, mock_flux):
|
||||
# Mock dependencies
|
||||
mock_ae = MagicMock()
|
||||
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
|
||||
|
||||
# Test when only global LoRA+ ratio is provided
|
||||
network = create_network(
|
||||
multiplier=1.0,
|
||||
network_dim=4,
|
||||
network_alpha=1.0,
|
||||
ae=mock_ae,
|
||||
text_encoders=mock_text_encoders,
|
||||
flux=mock_flux,
|
||||
nueral_dropout=None,
|
||||
**{"loraplus_lr_ratio": 2.0},
|
||||
)
|
||||
|
||||
# Verify only global ratio is set
|
||||
assert network.loraplus_lr_ratio == 2.0
|
||||
assert network.loraplus_unet_lr_ratio is None
|
||||
assert network.loraplus_text_encoder_lr_ratio is None
|
||||
|
||||
|
||||
def test_create_network_invalid_inputs(mock_text_encoder, mock_flux):
|
||||
# Mock dependencies
|
||||
mock_ae = MagicMock()
|
||||
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
|
||||
mock_flux = mock_flux
|
||||
|
||||
# Test invalid train_blocks
|
||||
with pytest.raises(AssertionError):
|
||||
create_network(
|
||||
multiplier=1.0,
|
||||
network_dim=4,
|
||||
network_alpha=1.0,
|
||||
ae=mock_ae,
|
||||
text_encoders=mock_text_encoders,
|
||||
flux=mock_flux,
|
||||
neuron_dropout=None,
|
||||
**{"train_blocks": "invalid"},
|
||||
)
|
||||
|
||||
# Test invalid in_dims
|
||||
with pytest.raises(AssertionError):
|
||||
create_network(
|
||||
multiplier=1.0,
|
||||
network_dim=4,
|
||||
network_alpha=1.0,
|
||||
ae=mock_ae,
|
||||
text_encoders=mock_text_encoders,
|
||||
flux=mock_flux,
|
||||
neuron_dropout=None,
|
||||
**{"in_dims": "[1,2,3]"}, # Should be 5 dimensions
|
||||
)
|
||||
|
||||
|
||||
def test_lora_network_initialization(mock_text_encoder, mock_flux):
|
||||
# Test basic initialization with default parameters
|
||||
lora_network = LoRANetwork(text_encoders=[mock_text_encoder, mock_text_encoder], unet=mock_flux)
|
||||
|
||||
# Check basic attributes
|
||||
assert lora_network.multiplier == 1.0
|
||||
assert lora_network.lora_dim == 4
|
||||
assert lora_network.alpha == 1
|
||||
assert lora_network.train_blocks == "all"
|
||||
|
||||
# Check LoRA modules are created
|
||||
assert len(lora_network.text_encoder_loras) > 0
|
||||
assert len(lora_network.unet_loras) > 0
|
||||
|
||||
|
||||
def test_lora_network_initialization_with_custom_params(mock_text_encoder, mock_flux):
|
||||
# Test initialization with custom parameters
|
||||
lora_network = LoRANetwork(
|
||||
text_encoders=[mock_text_encoder],
|
||||
unet=mock_flux,
|
||||
multiplier=0.5,
|
||||
lora_dim=8,
|
||||
alpha=2.0,
|
||||
dropout=0.1,
|
||||
rank_dropout=0.05,
|
||||
module_dropout=0.02,
|
||||
train_blocks="single",
|
||||
split_qkv=True,
|
||||
)
|
||||
|
||||
# Verify custom parameters are set correctly
|
||||
assert lora_network.multiplier == 0.5
|
||||
assert lora_network.lora_dim == 8
|
||||
assert lora_network.alpha == 2.0
|
||||
assert lora_network.dropout == 0.1
|
||||
assert lora_network.rank_dropout == 0.05
|
||||
assert lora_network.module_dropout == 0.02
|
||||
assert lora_network.train_blocks == "single"
|
||||
assert lora_network.split_qkv is True
|
||||
|
||||
|
||||
def test_lora_network_initialization_with_custom_modules_dim(mock_text_encoder, mock_flux):
|
||||
# Test initialization with custom module dimensions
|
||||
modules_dim = {"lora_te1_attns_0_layers_0": 16, "lora_unet_double_blocks_0_layers_0": 8}
|
||||
modules_alpha = {"lora_te1_attns_0_layers_0": 2, "lora_unet_double_blocks_0_layers_0": 1}
|
||||
|
||||
lora_network = LoRANetwork(
|
||||
text_encoders=[mock_text_encoder, mock_text_encoder], unet=mock_flux, modules_dim=modules_dim, modules_alpha=modules_alpha
|
||||
)
|
||||
|
||||
# [LoRAModule(
|
||||
# (lora_down): Linear(in_features=10, out_features=8, bias=False)
|
||||
# (lora_up): Linear(in_features=8, out_features=15, bias=False)
|
||||
# (org_module): Linear(in_features=10, out_features=15, bias=True)
|
||||
# )]
|
||||
# [LoRAModule(
|
||||
# (lora_down): Linear(in_features=10, out_features=16, bias=False)
|
||||
# (lora_up): Linear(in_features=16, out_features=15, bias=False)
|
||||
# (org_module): Linear(in_features=10, out_features=15, bias=True)
|
||||
# )]
|
||||
|
||||
assert isinstance(lora_network.unet_loras[0].lora_down, torch.nn.Linear)
|
||||
assert isinstance(lora_network.unet_loras[0].lora_up, torch.nn.Linear)
|
||||
assert lora_network.unet_loras[0].lora_down.weight.data.shape[0] == modules_dim["lora_unet_double_blocks_0_layers_0"]
|
||||
assert lora_network.unet_loras[0].lora_up.weight.data.shape[1] == modules_dim["lora_unet_double_blocks_0_layers_0"]
|
||||
assert lora_network.unet_loras[0].alpha == modules_alpha["lora_unet_double_blocks_0_layers_0"]
|
||||
|
||||
assert isinstance(lora_network.text_encoder_loras[0].lora_down, torch.nn.Linear)
|
||||
assert isinstance(lora_network.text_encoder_loras[0].lora_up, torch.nn.Linear)
|
||||
assert lora_network.text_encoder_loras[0].lora_down.weight.data.shape[0] == modules_dim["lora_te1_attns_0_layers_0"]
|
||||
assert lora_network.text_encoder_loras[0].lora_up.weight.data.shape[1] == modules_dim["lora_te1_attns_0_layers_0"]
|
||||
assert lora_network.text_encoder_loras[0].alpha == modules_alpha["lora_te1_attns_0_layers_0"]
|
||||
Reference in New Issue
Block a user