Update initialization, add lora_util, add tests

This commit is contained in:
rockerBOO
2025-03-25 18:22:07 -04:00
parent 0bad5ae9f1
commit 0ad3b3c2bd
5 changed files with 961 additions and 134 deletions

87
library/lora_util.py Normal file
View File

@@ -0,0 +1,87 @@
import torch
import math
import warnings
from typing import Optional
def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(lora_up.weight)
# URAE: Ultra-Resolution Adaptation with Ease
def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
device = device if device is not None else lora_down.weight.data.device
assert isinstance(device, torch.device), f"Invalid device type: {device}"
weight = org_module.weight.data.to(device, dtype=torch.float32)
with torch.autocast(device.type, dtype=torch.float32):
# SVD decomposition
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
# For URAE, use the LAST/SMALLEST singular values and vectors (residual components)
Vr = V[:, -rank:]
Sr = S[-rank:]
Sr /= rank
Uhr = Uh[-rank:, :]
# Create down and up matrices
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
# Get expected shapes
expected_down_shape = lora_down.weight.shape
expected_up_shape = lora_up.weight.shape
# Verify shapes match expected
if down.shape != expected_down_shape:
warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
if up.shape != expected_up_shape:
warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
# Assign to LoRA weights
lora_up.weight.data = up
lora_down.weight.data = down
# Optionally, subtract from original weight
weight = weight - scale * (up @ down)
weight_dtype = org_module.weight.data.dtype
org_module.weight.data = weight.to(dtype=weight_dtype)
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
weight_dtype = org_module.weight.data.dtype
device = device if device is not None else lora_down.weight.data.device
assert isinstance(device, torch.device), f"Invalid device type: {device}"
weight = org_module.weight.data.clone().to(device, dtype=torch.float32)
with torch.autocast(device.type, dtype=torch.float32):
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, : rank]
Sr = S[: rank]
Sr /= rank
Uhr = Uh[: rank]
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
# Get expected shapes
expected_down_shape = lora_down.weight.shape
expected_up_shape = lora_up.weight.shape
# Verify shapes match expected or reshape appropriately
if down.shape != expected_down_shape:
warnings.warn(UserWarning(f"Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
if up.shape != expected_up_shape:
warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
lora_up.weight.data = up.to(dtype=lora_up.weight.dtype)
lora_down.weight.data = down.to(dtype=lora_up.weight.dtype)
weight = weight.data - scale * (up @ down)
org_module.weight.data = weight.to(dtype=weight_dtype)

View File

@@ -1,5 +1,4 @@
# common functions for training
import argparse
import ast
import asyncio
@@ -6490,83 +6489,6 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
# endregion
def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(lora_up.weight)
# URAE: Ultra-Resolution Adaptation with Ease
def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None):
weight_dtype = org_module.weight.data.dtype
weight = org_module.weight.data.to(device="cuda", dtype=torch.float32)
# SVD decomposition
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
# For URAE, use the LAST/SMALLEST singular values and vectors (residual components)
Vr = V[:, -rank:]
Sr = S[-rank:]
Sr /= rank
Uhr = Uh[-rank:, :]
# Create down and up matrices
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
# Get expected shapes
expected_down_shape = lora_down.weight.shape
expected_up_shape = lora_up.weight.shape
# Verify shapes match expected
if down.shape != expected_down_shape:
print(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")
if up.shape != expected_up_shape:
print(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}")
# Assign to LoRA weights
lora_up.weight.data = up
lora_down.weight.data = down
# Optionally, subtract from original weight
weight = weight - scale * (up @ down)
org_module.weight.data = weight.to(dtype=weight_dtype)
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None):
weight_dtype = org_module.weight.data.dtype
weight = org_module.weight.data.to(device="cuda", dtype=torch.float32)
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, : rank]
Sr = S[: rank]
Sr /= rank
Uhr = Uh[: rank]
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up= Vr @ torch.diag(torch.sqrt(Sr))
# Get expected shapes
expected_down_shape = lora_down.weight.shape
expected_up_shape = lora_up.weight.shape
# Verify shapes match expected or reshape appropriately
if down.shape != expected_down_shape:
print(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")
# Additional reshaping logic if needed
if up.shape != expected_up_shape:
print(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}")
# Additional reshaping logic if needed
lora_up.weight.data = up
lora_down.weight.data = down
weight = weight.data - scale * (up @ down)
org_module.weight.data = weight.to(dtype=weight_dtype)
# collate_fn用 epoch,stepはmultiprocessing.Value
class collator_class:
def __init__(self, epoch, step, dataset):

View File

@@ -18,7 +18,7 @@ from torch import Tensor
from tqdm import tqdm
import re
from library.utils import setup_logging
from library.train_util import initialize_lora, initialize_pissa, initialize_urae
from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
setup_logging()
import logging
@@ -38,7 +38,7 @@ class LoRAModule(torch.nn.Module):
def __init__(
self,
lora_name,
org_module: torch.nn.Module,
org_module: torch.nn.Linear,
multiplier=1.0,
lora_dim=4,
alpha=1,
@@ -56,68 +56,32 @@ class LoRAModule(torch.nn.Module):
super().__init__()
self.lora_name = lora_name
if org_module.__class__.__name__ == "Conv2d":
in_dim = org_module.in_channels
out_dim = org_module.out_channels
else:
in_dim = org_module.in_features
out_dim = org_module.out_features
in_dim = org_module.in_features
out_dim = org_module.out_features
self.lora_dim = lora_dim
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
if isinstance(alpha, torch.Tensor):
alpha = alpha.detach().float().item() # without casting, bf16 causes error
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
self.split_dims = split_dims
self.initialize = initialize
if split_dims is None:
if org_module.__class__.__name__ == "Conv2d":
kernel_size = org_module.kernel_size
stride = org_module.stride
padding = org_module.padding
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
else:
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
if initialize == "urae":
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif initialize == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
else:
initialize_lora(self.lora_down, self.lora_up)
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
else:
# conv2d not supported
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
# print(f"split_dims: {split_dims}")
self.lora_down = torch.nn.ModuleList(
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
)
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
if initialize == "urae":
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
elif initialize == "pissa":
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
else:
initialize_lora(lora_down, lora_up)
with torch.autocast(org_module.weight.device.type), torch.no_grad():
self.initialize_weights(org_module)
# same as microsoft's
self.multiplier = multiplier
@@ -126,12 +90,44 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
def initialize_weights(self, org_module: torch.nn.Module):
if self.split_dims is None:
if self.initialize == "urae":
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
elif self.initialize == "pissa":
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = self.lora_up.weight.data.detach().clone()
self._org_lora_down = self.lora_down.weight.data.detach().clone()
else:
initialize_lora(self.lora_down, self.lora_up)
else:
assert isinstance(self.lora_down, torch.nn.ModuleList)
assert isinstance(self.lora_up, torch.nn.ModuleList)
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
if self.initialize == "urae":
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
elif self.initialize == "pissa":
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
# Need to store the original weights so we can get a plain LoRA out
self._org_lora_up = lora_up.weight.data.detach().clone()
self._org_lora_down = lora_down.weight.data.detach().clone()
else:
initialize_lora(lora_down, lora_up)
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
def forward(self, x) -> torch.Tensor:
org_forwarded = self.org_forward(x)
# module dropout
@@ -175,10 +171,6 @@ class LoRAModule(torch.nn.Module):
if self.rank_dropout is not None and self.training:
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
for i in range(len(lxs)):
if len(lx.size()) == 3:
masks[i] = masks[i].unsqueeze(1)
elif len(lx.size()) == 4:
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
lxs[i] = lxs[i] * masks[i]
# scaling for rank dropout: treat as if the rank is changed
@@ -765,6 +757,7 @@ class LoRANetwork(torch.nn.Module):
# 毎回すべてのモジュールを作るのは無駄なので要検討
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
skipped_te = []
text_encoders = text_encoders if isinstance(text_encoders, list) else [text_encoders]
for i, text_encoder in enumerate(text_encoders):
index = i
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
@@ -1103,8 +1096,7 @@ class LoRANetwork(torch.nn.Module):
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
# We need to create new low-rank matrices that represent this delta
# One approach is to do SVD on delta_w
U, S, V = torch.linalg.svd(delta_w, full_matrices=False)
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
# Take the top 2*r singular values (as suggested in the paper)
rank = rank * 2
@@ -1124,7 +1116,8 @@ class LoRANetwork(torch.nn.Module):
lora_down_key = f"{lora.lora_name}.lora_down.weight"
lora_up = state_dict[lora_up_key]
lora_down = state_dict[lora_down_key]
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim)
with torch.autocast("cuda"):
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim)
state_dict[lora_up_key] = up.detach()
state_dict[lora_down_key] = down.detach()
progress.update(1)

View File

@@ -0,0 +1,216 @@
import torch
import pytest
from library.lora_util import initialize_pissa
from tests.test_util import generate_synthetic_weights
def test_initialize_pissa_basic():
# Create a simple linear layer
org_module = torch.nn.Linear(10, 5)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
torch.nn.init.xavier_uniform_(org_module.weight)
torch.nn.init.zeros_(org_module.bias)
# Create LoRA layers with matching shapes
lora_down = torch.nn.Linear(10, 2)
lora_up = torch.nn.Linear(2, 5)
# Store original weight for comparison
original_weight = org_module.weight.data.clone()
# Call the initialization function
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
# Verify basic properties
assert lora_down.weight.data is not None
assert lora_up.weight.data is not None
assert org_module.weight.data is not None
# Check that the weights have been modified
assert not torch.equal(original_weight, org_module.weight.data)
def test_initialize_pissa_rank_constraints():
# Test with different rank values
org_module = torch.nn.Linear(20, 10)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
torch.nn.init.xavier_uniform_(org_module.weight)
torch.nn.init.zeros_(org_module.bias)
# Test with rank less than min dimension
lora_down = torch.nn.Linear(20, 3)
lora_up = torch.nn.Linear(3, 10)
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
# Test with rank equal to min dimension
lora_down = torch.nn.Linear(20, 10)
lora_up = torch.nn.Linear(10, 10)
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10)
def test_initialize_pissa_shape_mismatch():
# Test with shape mismatch to ensure warning is printed
org_module = torch.nn.Linear(20, 10)
# Intentionally mismatched shapes to test warning mechanism
lora_down = torch.nn.Linear(20, 5) # Different shape
lora_up = torch.nn.Linear(3, 15) # Different shape
with pytest.warns(UserWarning):
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
def test_initialize_pissa_scaling():
# Test different scaling factors
scales = [0.0, 0.1, 1.0]
for scale in scales:
org_module = torch.nn.Linear(10, 5)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
original_weight = org_module.weight.data.clone()
lora_down = torch.nn.Linear(10, 2)
lora_up = torch.nn.Linear(2, 5)
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2)
# Check that the weight modification follows the scaling
weight_diff = original_weight - org_module.weight.data
expected_diff = scale * (lora_up.weight.data @ lora_down.weight.data)
torch.testing.assert_close(weight_diff, expected_diff, rtol=1e-4, atol=1e-4)
def test_initialize_pissa_dtype():
# Test with different data types
dtypes = [torch.float16, torch.float32, torch.float64]
for dtype in dtypes:
org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
lora_down = torch.nn.Linear(10, 2)
lora_up = torch.nn.Linear(2, 5)
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
# Verify output dtype matches input
assert org_module.weight.dtype == dtype
def test_initialize_pissa_svd_properties():
# Verify SVD decomposition properties
org_module = torch.nn.Linear(20, 10)
lora_down = torch.nn.Linear(20, 3)
lora_up = torch.nn.Linear(3, 10)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
original_weight = org_module.weight.data.clone()
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
# Reconstruct the weight
reconstructed_weight = original_weight - 0.1 * (lora_up.weight.data @ lora_down.weight.data)
# Check reconstruction is close to original
torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4)
def test_initialize_pissa_device_handling():
# Test different device scenarios
devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")]
for device in devices:
# Create modules on specific device
org_module = torch.nn.Linear(10, 5).to(device)
lora_down = torch.nn.Linear(10, 2).to(device)
lora_up = torch.nn.Linear(2, 5).to(device)
# Test initialization with explicit device
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device)
# Verify modules are on the correct device
assert org_module.weight.data.device.type == device.type
assert lora_down.weight.data.device.type == device.type
assert lora_up.weight.data.device.type == device.type
def test_initialize_pissa_dtype_preservation():
# Test dtype preservation and conversion
dtypes = [torch.float16, torch.float32, torch.float64]
for dtype in dtypes:
org_module = torch.nn.Linear(10, 5).to(dtype=dtype)
lora_down = torch.nn.Linear(10, 2).to(dtype=dtype)
lora_up = torch.nn.Linear(2, 5).to(dtype=dtype)
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)
assert org_module.weight.dtype == dtype
assert lora_down.weight.dtype == dtype
assert lora_up.weight.dtype == dtype
def test_initialize_pissa_rank_limits():
# Test rank limits
org_module = torch.nn.Linear(10, 5)
# Test minimum rank (should work)
lora_down_min = torch.nn.Linear(10, 1)
lora_up_min = torch.nn.Linear(1, 5)
initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1)
# Test maximum rank (rank = min(input_dim, output_dim))
max_rank = min(10, 5)
lora_down_max = torch.nn.Linear(10, max_rank)
lora_up_max = torch.nn.Linear(max_rank, 5)
initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank)
def test_initialize_pissa_numerical_stability():
# Test with numerically challenging scenarios
scenarios = [
torch.randn(20, 10) * 1e-10, # Very small values
torch.randn(20, 10) * 1e10, # Very large values
torch.zeros(20, 10), # Zero matrix
]
for i, weight_matrix in enumerate(scenarios):
org_module = torch.nn.Linear(20, 10)
org_module.weight.data = weight_matrix
lora_down = torch.nn.Linear(10, 3)
lora_up = torch.nn.Linear(3, 20)
try:
initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
except Exception as e:
pytest.fail(f"Initialization failed for scenario ({i}): {e}")
def test_initialize_pissa_scale_effects():
# Test different scaling factors
org_module = torch.nn.Linear(10, 5)
original_weight = org_module.weight.data.clone()
test_scales = [0.0, 0.1, 0.5, 1.0]
for scale in test_scales:
# Reset module for each test
org_module.weight.data = original_weight.clone()
lora_down = torch.nn.Linear(10, 2)
lora_up = torch.nn.Linear(2, 5)
initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2)
# Verify weight modification proportional to scale
weight_diff = original_weight - org_module.weight.data
# Approximate check of scaling effect
if scale == 0.0:
torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6)
else:
assert not torch.allclose(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6)

View File

@@ -0,0 +1,609 @@
import pytest
import torch
import torch.nn as nn
from networks.lora_flux import LoRAModule, LoRANetwork, create_network
from tests.test_util import generate_synthetic_weights
from unittest.mock import MagicMock
def test_basic_linear_module_initialization():
# Test basic Linear module initialization
org_module = nn.Linear(10, 20)
lora_module = LoRAModule(lora_name="test_linear", org_module=org_module, lora_dim=4)
# Check basic attributes
assert lora_module.lora_name == "test_linear"
assert lora_module.lora_dim == 4
# Check LoRA layers
assert isinstance(lora_module.lora_down, nn.Linear)
assert isinstance(lora_module.lora_up, nn.Linear)
# Check input and output dimensions
assert lora_module.lora_down.in_features == 10
assert lora_module.lora_down.out_features == 4
assert lora_module.lora_up.in_features == 4
assert lora_module.lora_up.out_features == 20
def test_split_dims_initialization():
# Test initialization with split_dims
org_module = nn.Linear(10, 15)
lora_module = LoRAModule(lora_name="test_split_dims", org_module=org_module, lora_dim=4, split_dims=[5, 5, 5])
# Check split_dims specific attributes
assert lora_module.split_dims == [5, 5, 5]
assert isinstance(lora_module.lora_down, nn.ModuleList)
assert isinstance(lora_module.lora_up, nn.ModuleList)
# Check number of split modules
assert len(lora_module.lora_down) == 3
assert len(lora_module.lora_up) == 3
# Check dimensions of split modules
for down, up in zip(lora_module.lora_down, lora_module.lora_up):
assert down.in_features == 10
assert down.out_features == 4
assert up.in_features == 4
assert up.out_features in [5, 5, 5]
def test_alpha_scaling():
# Test alpha scaling
org_module = nn.Linear(10, 20)
# Default alpha (should be equal to lora_dim)
lora_module1 = LoRAModule(lora_name="test_alpha1", org_module=org_module, lora_dim=4, alpha=0)
assert lora_module1.scale == 1.0
# Custom alpha
lora_module2 = LoRAModule(lora_name="test_alpha2", org_module=org_module, lora_dim=4, alpha=2)
assert lora_module2.scale == 0.5
def test_initialization_methods():
# Test different initialization methods
org_module = nn.Linear(10, 20)
org_module.weight.data = generate_synthetic_weights(org_module.weight)
# Default initialization
lora_module1 = LoRAModule(lora_name="test_init_default", org_module=org_module, lora_dim=4)
assert lora_module1.lora_down.weight.shape == (4, 10)
assert lora_module1.lora_up.weight.shape == (20, 4)
# URAE initialization
lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4, initialize="urae")
assert hasattr(lora_module2, "_org_lora_up") and lora_module2._org_lora_down is not None
assert hasattr(lora_module2, "_org_lora_down") and lora_module2._org_lora_down is not None
assert lora_module2.lora_down.weight.shape == (4, 10)
assert lora_module2.lora_up.weight.shape == (20, 4)
# PISSA initialization
lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4, initialize="pissa")
assert hasattr(lora_module3, "_org_lora_up") and lora_module3._org_lora_down is not None
assert hasattr(lora_module3, "_org_lora_down") and lora_module3._org_lora_down is not None
assert lora_module3.lora_down.weight.shape == (4, 10)
assert lora_module3.lora_up.weight.shape == (20, 4)
@torch.no_grad()
def test_forward_basic_linear():
# Create a basic linear module
org_module = nn.Linear(10, 20)
org_module.weight.data = torch.testing.make_tensor(
org_module.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
)
lora_module = LoRAModule(lora_name="test_forward", org_module=org_module, lora_dim=4, alpha=4, multiplier=1.0)
lora_module.apply_to()
assert isinstance(lora_module.lora_down, nn.Linear)
assert isinstance(lora_module.lora_up, nn.Linear)
lora_module.lora_down.weight.data = torch.testing.make_tensor(
lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
)
lora_module.lora_up.weight.data = torch.testing.make_tensor(
lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
)
# Create input
x = torch.ones(5, 10)
# Perform forward pass
output = lora_module.forward(x)
# Structural assertions
assert output is not None, "Output should not be None"
assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor"
# Shape assertions
assert output.shape == (5, 20), "Output shape should match expected dimensions"
# Type and device assertions
assert output.dtype == torch.float32, "Output should be float32"
assert output.device == x.device, "Output should be on the same device as input"
def test_forward_module_dropout():
# Create a basic linear module
org_module = nn.Linear(10, 20)
lora_module = LoRAModule(
lora_name="test_module_dropout",
org_module=org_module,
lora_dim=4,
multiplier=1.0,
module_dropout=1.0, # Always drop
)
lora_module.apply_to()
# Create input
x = torch.ones(5, 10)
# Enable training mode
lora_module.train()
# Perform forward pass
output = lora_module.forward(x)
# Check if output is same as original module output
org_output = org_module(x)
torch.testing.assert_close(output, org_output)
def test_forward_rank_dropout():
# Create a basic linear module
org_module = nn.Linear(10, 20)
lora_module = LoRAModule(
lora_name="test_rank_dropout",
org_module=org_module,
lora_dim=4,
multiplier=1.0,
rank_dropout=0.5, # 50% dropout
)
lora_module.apply_to()
assert isinstance(lora_module.lora_down, nn.Linear)
assert isinstance(lora_module.lora_up, nn.Linear)
# Make lora weights predictable
lora_module.lora_down.weight.data = torch.testing.make_tensor(
lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
)
lora_module.lora_up.weight.data = torch.testing.make_tensor(
lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
)
# Create input
x = torch.ones(5, 10)
# Enable training mode
lora_module.train()
# Perform multiple forward passes to show dropout effect
outputs = [lora_module.forward(x) for _ in range(10)]
# Check that outputs are not all identical due to rank dropout
differences = [
torch.all(torch.eq(outputs[i], outputs[j])).item() for i in range(len(outputs)) for j in range(i + 1, len(outputs))
]
assert not all(differences)
def test_forward_split_dims():
# Create a basic linear module with split dimensions
org_module = nn.Linear(10, 15)
lora_module = LoRAModule(lora_name="test_split_dims", org_module=org_module, lora_dim=4, multiplier=1.0, split_dims=[5, 5, 5])
lora_module.apply_to()
assert isinstance(lora_module.lora_down, nn.ModuleList)
assert isinstance(lora_module.lora_up, nn.ModuleList)
# Make lora weights predictable
for down in lora_module.lora_down:
assert isinstance(down, nn.Linear)
down.weight.data = torch.testing.make_tensor(down.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0)
for up in lora_module.lora_up:
assert isinstance(up, nn.Linear)
up.weight.data = torch.testing.make_tensor(up.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0)
# Create input
x = torch.ones(5, 10)
# Perform forward pass
output = lora_module.forward(x)
# Check output dimensions
assert output.shape == (5, 15)
def test_forward_dropout():
# Create a basic linear module
org_module = nn.Linear(10, 20)
lora_module = LoRAModule(
lora_name="test_dropout",
org_module=org_module,
lora_dim=4,
multiplier=1.0,
dropout=0.5, # 50% dropout
)
lora_module.apply_to()
assert isinstance(lora_module.lora_down, nn.Linear)
assert isinstance(lora_module.lora_up, nn.Linear)
# Make lora weights predictable
lora_module.lora_down.weight.data = torch.testing.make_tensor(
lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
)
lora_module.lora_up.weight.data = torch.testing.make_tensor(
lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
)
# Create input
x = torch.ones(5, 10)
# Enable training mode
lora_module.train()
# Perform multiple forward passes to show dropout effect
outputs = [lora_module.forward(x) for _ in range(10)]
# Check that outputs are not all identical due to dropout
differences = [
torch.all(torch.eq(outputs[i], outputs[j])).item() for i in range(len(outputs)) for j in range(i + 1, len(outputs))
]
assert not all(differences)
def test_create_network_default_parameters(mock_text_encoder, mock_flux):
# Mock dependencies
mock_ae = MagicMock()
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
# Call the function with minimal parameters
network = create_network(
multiplier=1.0, network_dim=None, network_alpha=None, ae=mock_ae, text_encoders=mock_text_encoders, flux=mock_flux
)
# Assertions
assert network is not None
assert network.multiplier == 1.0
assert network.lora_dim == 4 # default network_dim
assert network.alpha == 1.0 # default network_alpha
@pytest.fixture
def mock_text_encoder():
class CLIPAttention(nn.Module):
def __init__(self):
super().__init__()
# Add some dummy layers to simulate a CLIPAttention
self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)])
class MockTextEncoder(nn.Module):
def __init__(self):
super().__init__()
# Add some dummy layers to simulate a CLIPTextModel
self.attns = torch.nn.ModuleList([CLIPAttention() for _ in range(3)])
return MockTextEncoder()
@pytest.fixture
def mock_flux():
class DoubleStreamBlock(nn.Module):
def __init__(self):
super().__init__()
# Add some dummy layers to simulate a DoubleStreamBlock
self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)])
class SingleStreamBlock(nn.Module):
def __init__(self):
super().__init__()
# Add some dummy layers to simulate a SingleStreamBlock
self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)])
class MockFlux(torch.nn.Module):
def __init__(self):
super().__init__()
# Add some dummy layers to simulate a Flux
self.double_blocks = torch.nn.ModuleList([DoubleStreamBlock() for _ in range(3)])
self.single_blocks = torch.nn.ModuleList([SingleStreamBlock() for _ in range(3)])
return MockFlux()
def test_create_network_custom_parameters(mock_text_encoder, mock_flux):
# Mock dependencies
mock_ae = MagicMock()
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
# Prepare custom parameters
custom_params = {
"conv_dim": 8,
"conv_alpha": 0.5,
"img_attn_dim": 16,
"txt_attn_dim": 16,
"neuron_dropout": 0.1,
"rank_dropout": 0.2,
"module_dropout": 0.3,
"train_blocks": "double",
"split_qkv": "True",
"train_t5xxl": "True",
"in_dims": "[64, 32, 16, 8, 4]",
"verbose": "True",
}
# Call the function with custom parameters
network = create_network(
multiplier=1.5,
network_dim=8,
network_alpha=2.0,
ae=mock_ae,
text_encoders=mock_text_encoders,
flux=mock_flux,
**custom_params,
)
# Assertions
assert network is not None
assert network.multiplier == 1.5
assert network.lora_dim == 8
assert network.alpha == 2.0
assert network.conv_lora_dim == 8
assert network.conv_alpha == 0.5
assert network.train_blocks == "double"
assert network.split_qkv is True
assert network.train_t5xxl is True
def test_create_network_block_indices(mock_text_encoder, mock_flux):
# Mock dependencies
mock_ae = MagicMock()
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
# Test block indices parsing
network = create_network(
multiplier=1.0,
network_dim=4,
network_alpha=1.0,
ae=mock_ae,
text_encoders=mock_text_encoders,
flux=mock_flux,
neuron_dropout=None,
**{"train_double_block_indices": "0-2,4", "train_single_block_indices": "1,3"},
)
# Assertions would depend on the exact implementation of parsing
assert network.train_double_block_indices is not None
assert network.train_single_block_indices is not None
double_block_indices = [
True,
True,
True,
False,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
]
single_block_indices = [
False,
True,
False,
True,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
False,
]
assert network.train_double_block_indices == double_block_indices
assert network.train_single_block_indices == single_block_indices
def test_create_network_loraplus_ratios(mock_text_encoder, mock_flux):
# Mock dependencies
mock_ae = MagicMock()
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
# Test LoRA+ ratios
network = create_network(
multiplier=1.0,
network_dim=4,
network_alpha=1.0,
ae=mock_ae,
text_encoders=mock_text_encoders,
flux=mock_flux,
neuron_dropout=None,
**{"loraplus_lr_ratio": 2.0, "loraplus_unet_lr_ratio": 1.5, "loraplus_text_encoder_lr_ratio": 1.0},
)
# Verify LoRA+ ratios were set correctly
assert network.loraplus_lr_ratio == 2.0
assert network.loraplus_unet_lr_ratio == 1.5
assert network.loraplus_text_encoder_lr_ratio == 1.0
def test_create_network_loraplus_default_ratio(mock_text_encoder, mock_flux):
# Mock dependencies
mock_ae = MagicMock()
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
# Test when only global LoRA+ ratio is provided
network = create_network(
multiplier=1.0,
network_dim=4,
network_alpha=1.0,
ae=mock_ae,
text_encoders=mock_text_encoders,
flux=mock_flux,
nueral_dropout=None,
**{"loraplus_lr_ratio": 2.0},
)
# Verify only global ratio is set
assert network.loraplus_lr_ratio == 2.0
assert network.loraplus_unet_lr_ratio is None
assert network.loraplus_text_encoder_lr_ratio is None
def test_create_network_invalid_inputs(mock_text_encoder, mock_flux):
# Mock dependencies
mock_ae = MagicMock()
mock_text_encoders = [mock_text_encoder, mock_text_encoder]
mock_flux = mock_flux
# Test invalid train_blocks
with pytest.raises(AssertionError):
create_network(
multiplier=1.0,
network_dim=4,
network_alpha=1.0,
ae=mock_ae,
text_encoders=mock_text_encoders,
flux=mock_flux,
neuron_dropout=None,
**{"train_blocks": "invalid"},
)
# Test invalid in_dims
with pytest.raises(AssertionError):
create_network(
multiplier=1.0,
network_dim=4,
network_alpha=1.0,
ae=mock_ae,
text_encoders=mock_text_encoders,
flux=mock_flux,
neuron_dropout=None,
**{"in_dims": "[1,2,3]"}, # Should be 5 dimensions
)
def test_lora_network_initialization(mock_text_encoder, mock_flux):
# Test basic initialization with default parameters
lora_network = LoRANetwork(text_encoders=[mock_text_encoder, mock_text_encoder], unet=mock_flux)
# Check basic attributes
assert lora_network.multiplier == 1.0
assert lora_network.lora_dim == 4
assert lora_network.alpha == 1
assert lora_network.train_blocks == "all"
# Check LoRA modules are created
assert len(lora_network.text_encoder_loras) > 0
assert len(lora_network.unet_loras) > 0
def test_lora_network_initialization_with_custom_params(mock_text_encoder, mock_flux):
# Test initialization with custom parameters
lora_network = LoRANetwork(
text_encoders=[mock_text_encoder],
unet=mock_flux,
multiplier=0.5,
lora_dim=8,
alpha=2.0,
dropout=0.1,
rank_dropout=0.05,
module_dropout=0.02,
train_blocks="single",
split_qkv=True,
)
# Verify custom parameters are set correctly
assert lora_network.multiplier == 0.5
assert lora_network.lora_dim == 8
assert lora_network.alpha == 2.0
assert lora_network.dropout == 0.1
assert lora_network.rank_dropout == 0.05
assert lora_network.module_dropout == 0.02
assert lora_network.train_blocks == "single"
assert lora_network.split_qkv is True
def test_lora_network_initialization_with_custom_modules_dim(mock_text_encoder, mock_flux):
# Test initialization with custom module dimensions
modules_dim = {"lora_te1_attns_0_layers_0": 16, "lora_unet_double_blocks_0_layers_0": 8}
modules_alpha = {"lora_te1_attns_0_layers_0": 2, "lora_unet_double_blocks_0_layers_0": 1}
lora_network = LoRANetwork(
text_encoders=[mock_text_encoder, mock_text_encoder], unet=mock_flux, modules_dim=modules_dim, modules_alpha=modules_alpha
)
# [LoRAModule(
# (lora_down): Linear(in_features=10, out_features=8, bias=False)
# (lora_up): Linear(in_features=8, out_features=15, bias=False)
# (org_module): Linear(in_features=10, out_features=15, bias=True)
# )]
# [LoRAModule(
# (lora_down): Linear(in_features=10, out_features=16, bias=False)
# (lora_up): Linear(in_features=16, out_features=15, bias=False)
# (org_module): Linear(in_features=10, out_features=15, bias=True)
# )]
assert isinstance(lora_network.unet_loras[0].lora_down, torch.nn.Linear)
assert isinstance(lora_network.unet_loras[0].lora_up, torch.nn.Linear)
assert lora_network.unet_loras[0].lora_down.weight.data.shape[0] == modules_dim["lora_unet_double_blocks_0_layers_0"]
assert lora_network.unet_loras[0].lora_up.weight.data.shape[1] == modules_dim["lora_unet_double_blocks_0_layers_0"]
assert lora_network.unet_loras[0].alpha == modules_alpha["lora_unet_double_blocks_0_layers_0"]
assert isinstance(lora_network.text_encoder_loras[0].lora_down, torch.nn.Linear)
assert isinstance(lora_network.text_encoder_loras[0].lora_up, torch.nn.Linear)
assert lora_network.text_encoder_loras[0].lora_down.weight.data.shape[0] == modules_dim["lora_te1_attns_0_layers_0"]
assert lora_network.text_encoder_loras[0].lora_up.weight.data.shape[1] == modules_dim["lora_te1_attns_0_layers_0"]
assert lora_network.text_encoder_loras[0].alpha == modules_alpha["lora_te1_attns_0_layers_0"]