From 0ad3b3c2bde89005b7ec6c229c87eeb96dee1ba4 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Tue, 25 Mar 2025 18:22:07 -0400 Subject: [PATCH] Update initialization, add lora_util, add tests --- library/lora_util.py | 87 +++++ library/train_util.py | 78 ---- networks/lora_flux.py | 105 +++--- tests/library/test_lora_util.py | 216 +++++++++++ tests/networks/test_lora_flux.py | 609 +++++++++++++++++++++++++++++++ 5 files changed, 961 insertions(+), 134 deletions(-) create mode 100644 library/lora_util.py create mode 100644 tests/library/test_lora_util.py create mode 100644 tests/networks/test_lora_flux.py diff --git a/library/lora_util.py b/library/lora_util.py new file mode 100644 index 00000000..d76cd17b --- /dev/null +++ b/library/lora_util.py @@ -0,0 +1,87 @@ +import torch +import math +import warnings +from typing import Optional + +def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module): + torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(lora_up.weight) + +# URAE: Ultra-Resolution Adaptation with Ease +def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None): + device = device if device is not None else lora_down.weight.data.device + assert isinstance(device, torch.device), f"Invalid device type: {device}" + + weight = org_module.weight.data.to(device, dtype=torch.float32) + + with torch.autocast(device.type, dtype=torch.float32): + # SVD decomposition + V, S, Uh = torch.linalg.svd(weight, full_matrices=False) + + # For URAE, use the LAST/SMALLEST singular values and vectors (residual components) + Vr = V[:, -rank:] + Sr = S[-rank:] + Sr /= rank + Uhr = Uh[-rank:, :] + + # Create down and up matrices + down = torch.diag(torch.sqrt(Sr)) @ Uhr + up = Vr @ torch.diag(torch.sqrt(Sr)) + + # Get expected shapes + expected_down_shape = lora_down.weight.shape + 
expected_up_shape = lora_up.weight.shape + + # Verify shapes match expected + if down.shape != expected_down_shape: + warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")) + + if up.shape != expected_up_shape: + warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}")) + + # Assign to LoRA weights + lora_up.weight.data = up + lora_down.weight.data = down + + # Optionally, subtract from original weight + weight = weight - scale * (up @ down) + + weight_dtype = org_module.weight.data.dtype + org_module.weight.data = weight.to(dtype=weight_dtype) + +# PiSSA: Principal Singular Values and Singular Vectors Adaptation +def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None): + weight_dtype = org_module.weight.data.dtype + + device = device if device is not None else lora_down.weight.data.device + assert isinstance(device, torch.device), f"Invalid device type: {device}" + + weight = org_module.weight.data.clone().to(device, dtype=torch.float32) + + with torch.autocast(device.type, dtype=torch.float32): + # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, + V, S, Uh = torch.linalg.svd(weight, full_matrices=False) + Vr = V[:, : rank] + Sr = S[: rank] + Sr /= rank + Uhr = Uh[: rank] + + down = torch.diag(torch.sqrt(Sr)) @ Uhr + up = Vr @ torch.diag(torch.sqrt(Sr)) + + # Get expected shapes + expected_down_shape = lora_down.weight.shape + expected_up_shape = lora_up.weight.shape + + # Verify shapes match expected or reshape appropriately + if down.shape != expected_down_shape: + warnings.warn(UserWarning(f"Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")) + + if up.shape != expected_up_shape: + warnings.warn(UserWarning(f"Up matrix shape mismatch. 
Got {up.shape}, expected {expected_up_shape}")) + + lora_up.weight.data = up.to(dtype=lora_up.weight.dtype) + lora_down.weight.data = down.to(dtype=lora_up.weight.dtype) + + weight = weight.data - scale * (up @ down) + org_module.weight.data = weight.to(dtype=weight_dtype) diff --git a/library/train_util.py b/library/train_util.py index 6643b5bf..331a9a24 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1,5 +1,4 @@ # common functions for training - import argparse import ast import asyncio @@ -6490,83 +6489,6 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # endregion - -def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module): - torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(lora_up.weight) - -# URAE: Ultra-Resolution Adaptation with Ease -def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None): - weight_dtype = org_module.weight.data.dtype - weight = org_module.weight.data.to(device="cuda", dtype=torch.float32) - - # SVD decomposition - V, S, Uh = torch.linalg.svd(weight, full_matrices=False) - - # For URAE, use the LAST/SMALLEST singular values and vectors (residual components) - Vr = V[:, -rank:] - Sr = S[-rank:] - Sr /= rank - Uhr = Uh[-rank:, :] - - # Create down and up matrices - down = torch.diag(torch.sqrt(Sr)) @ Uhr - up = Vr @ torch.diag(torch.sqrt(Sr)) - - # Get expected shapes - expected_down_shape = lora_down.weight.shape - expected_up_shape = lora_up.weight.shape - - # Verify shapes match expected - if down.shape != expected_down_shape: - print(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}") - - if up.shape != expected_up_shape: - print(f"Warning: Up matrix shape mismatch. 
Got {up.shape}, expected {expected_up_shape}") - - # Assign to LoRA weights - lora_up.weight.data = up - lora_down.weight.data = down - - # Optionally, subtract from original weight - weight = weight - scale * (up @ down) - org_module.weight.data = weight.to(dtype=weight_dtype) - -# PiSSA: Principal Singular Values and Singular Vectors Adaptation -def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None): - weight_dtype = org_module.weight.data.dtype - - weight = org_module.weight.data.to(device="cuda", dtype=torch.float32) - - # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, - V, S, Uh = torch.linalg.svd(weight, full_matrices=False) - Vr = V[:, : rank] - Sr = S[: rank] - Sr /= rank - Uhr = Uh[: rank] - - down = torch.diag(torch.sqrt(Sr)) @ Uhr - up= Vr @ torch.diag(torch.sqrt(Sr)) - - # Get expected shapes - expected_down_shape = lora_down.weight.shape - expected_up_shape = lora_up.weight.shape - - # Verify shapes match expected or reshape appropriately - if down.shape != expected_down_shape: - print(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}") - # Additional reshaping logic if needed - - if up.shape != expected_up_shape: - print(f"Warning: Up matrix shape mismatch. 
Got {up.shape}, expected {expected_up_shape}") - # Additional reshaping logic if needed - - lora_up.weight.data = up - lora_down.weight.data = down - - weight = weight.data - scale * (up @ down) - org_module.weight.data = weight.to(dtype=weight_dtype) - # collate_fn用 epoch,stepはmultiprocessing.Value class collator_class: def __init__(self, epoch, step, dataset): diff --git a/networks/lora_flux.py b/networks/lora_flux.py index b9b11423..e6b9b95b 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -18,7 +18,7 @@ from torch import Tensor from tqdm import tqdm import re from library.utils import setup_logging -from library.train_util import initialize_lora, initialize_pissa, initialize_urae +from library.lora_util import initialize_lora, initialize_pissa, initialize_urae setup_logging() import logging @@ -38,7 +38,7 @@ class LoRAModule(torch.nn.Module): def __init__( self, lora_name, - org_module: torch.nn.Module, + org_module: torch.nn.Linear, multiplier=1.0, lora_dim=4, alpha=1, @@ -56,68 +56,32 @@ class LoRAModule(torch.nn.Module): super().__init__() self.lora_name = lora_name - if org_module.__class__.__name__ == "Conv2d": - in_dim = org_module.in_channels - out_dim = org_module.out_channels - else: - in_dim = org_module.in_features - out_dim = org_module.out_features + in_dim = org_module.in_features + out_dim = org_module.out_features self.lora_dim = lora_dim - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + if isinstance(alpha, torch.Tensor): + alpha = alpha.detach().float().item() # without casting, bf16 causes error alpha = self.lora_dim if alpha is None or alpha == 0 else alpha self.scale = alpha / self.lora_dim self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える self.split_dims = split_dims + self.initialize = initialize if split_dims is None: - if org_module.__class__.__name__ == "Conv2d": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = 
org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) - else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - - if initialize == "urae": - initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim) - # Need to store the original weights so we can get a plain LoRA out - self._org_lora_up = self.lora_up.weight.data.detach().clone() - self._org_lora_down = self.lora_down.weight.data.detach().clone() - elif initialize == "pissa": - initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim) - # Need to store the original weights so we can get a plain LoRA out - self._org_lora_up = self.lora_up.weight.data.detach().clone() - self._org_lora_down = self.lora_down.weight.data.detach().clone() - else: - initialize_lora(self.lora_down, self.lora_up) + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) else: - # conv2d not supported assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim" - assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear" - # print(f"split_dims: {split_dims}") self.lora_down = torch.nn.ModuleList( [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))] ) self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims]) - for lora_down, lora_up in zip(self.lora_down, self.lora_up): - if initialize == "urae": - initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim) - # Need to store the original weights so we can get a plain LoRA out - self._org_lora_up = lora_up.weight.data.detach().clone() - self._org_lora_down = 
lora_down.weight.data.detach().clone() - elif initialize == "pissa": - initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim) - # Need to store the original weights so we can get a plain LoRA out - self._org_lora_up = lora_up.weight.data.detach().clone() - self._org_lora_down = lora_down.weight.data.detach().clone() - else: - initialize_lora(lora_down, lora_up) + + with torch.autocast(org_module.weight.device.type), torch.no_grad(): + self.initialize_weights(org_module) # same as microsoft's self.multiplier = multiplier @@ -126,12 +90,44 @@ class LoRAModule(torch.nn.Module): self.rank_dropout = rank_dropout self.module_dropout = module_dropout + + def initialize_weights(self, org_module: torch.nn.Module): + if self.split_dims is None: + if self.initialize == "urae": + initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim) + # Need to store the original weights so we can get a plain LoRA out + self._org_lora_up = self.lora_up.weight.data.detach().clone() + self._org_lora_down = self.lora_down.weight.data.detach().clone() + elif self.initialize == "pissa": + initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim) + # Need to store the original weights so we can get a plain LoRA out + self._org_lora_up = self.lora_up.weight.data.detach().clone() + self._org_lora_down = self.lora_down.weight.data.detach().clone() + else: + initialize_lora(self.lora_down, self.lora_up) + else: + assert isinstance(self.lora_down, torch.nn.ModuleList) + assert isinstance(self.lora_up, torch.nn.ModuleList) + for lora_down, lora_up in zip(self.lora_down, self.lora_up): + if self.initialize == "urae": + initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim) + # Need to store the original weights so we can get a plain LoRA out + self._org_lora_up = lora_up.weight.data.detach().clone() + self._org_lora_down = lora_down.weight.data.detach().clone() + elif self.initialize == "pissa": + 
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim) + # Need to store the original weights so we can get a plain LoRA out + self._org_lora_up = lora_up.weight.data.detach().clone() + self._org_lora_down = lora_down.weight.data.detach().clone() + else: + initialize_lora(lora_down, lora_up) + def apply_to(self): self.org_forward = self.org_module.forward self.org_module.forward = self.forward del self.org_module - def forward(self, x): + def forward(self, x) -> torch.Tensor: org_forwarded = self.org_forward(x) # module dropout @@ -175,10 +171,6 @@ class LoRAModule(torch.nn.Module): if self.rank_dropout is not None and self.training: masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs] for i in range(len(lxs)): - if len(lx.size()) == 3: - masks[i] = masks[i].unsqueeze(1) - elif len(lx.size()) == 4: - masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1) lxs[i] = lxs[i] * masks[i] # scaling for rank dropout: treat as if the rank is changed @@ -765,6 +757,7 @@ class LoRANetwork(torch.nn.Module): # 毎回すべてのモジュールを作るのは無駄なので要検討 self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = [] skipped_te = [] + text_encoders = text_encoders if isinstance(text_encoders, list) else [text_encoders] for i, text_encoder in enumerate(text_encoders): index = i if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False @@ -1103,8 +1096,7 @@ class LoRANetwork(torch.nn.Module): delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) # We need to create new low-rank matrices that represent this delta - # One approach is to do SVD on delta_w - U, S, V = torch.linalg.svd(delta_w, full_matrices=False) + U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False) # Take the top 2*r singular values (as suggested in the paper) rank = rank * 2 @@ -1124,7 +1116,8 @@ class LoRANetwork(torch.nn.Module): lora_down_key = 
f"{lora.lora_name}.lora_down.weight" lora_up = state_dict[lora_up_key] lora_down = state_dict[lora_down_key] - up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim) + with torch.autocast("cuda"): + up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim) state_dict[lora_up_key] = up.detach() state_dict[lora_down_key] = down.detach() progress.update(1) diff --git a/tests/library/test_lora_util.py b/tests/library/test_lora_util.py new file mode 100644 index 00000000..861a34f9 --- /dev/null +++ b/tests/library/test_lora_util.py @@ -0,0 +1,216 @@ +import torch +import pytest +from library.lora_util import initialize_pissa +from tests.test_util import generate_synthetic_weights + + +def test_initialize_pissa_basic(): + # Create a simple linear layer + org_module = torch.nn.Linear(10, 5) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + torch.nn.init.xavier_uniform_(org_module.weight) + torch.nn.init.zeros_(org_module.bias) + + # Create LoRA layers with matching shapes + lora_down = torch.nn.Linear(10, 2) + lora_up = torch.nn.Linear(2, 5) + + # Store original weight for comparison + original_weight = org_module.weight.data.clone() + + # Call the initialization function + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) + + # Verify basic properties + assert lora_down.weight.data is not None + assert lora_up.weight.data is not None + assert org_module.weight.data is not None + + # Check that the weights have been modified + assert not torch.equal(original_weight, org_module.weight.data) + + +def test_initialize_pissa_rank_constraints(): + # Test with different rank values + org_module = torch.nn.Linear(20, 10) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + torch.nn.init.xavier_uniform_(org_module.weight) + torch.nn.init.zeros_(org_module.bias) + + # Test with rank less than 
min dimension + lora_down = torch.nn.Linear(20, 3) + lora_up = torch.nn.Linear(3, 10) + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + # Test with rank equal to min dimension + lora_down = torch.nn.Linear(20, 10) + lora_up = torch.nn.Linear(10, 10) + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10) + + +def test_initialize_pissa_shape_mismatch(): + # Test with shape mismatch to ensure warning is printed + org_module = torch.nn.Linear(20, 10) + + # Intentionally mismatched shapes to test warning mechanism + lora_down = torch.nn.Linear(20, 5) # Different shape + lora_up = torch.nn.Linear(3, 15) # Different shape + + with pytest.warns(UserWarning): + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + +def test_initialize_pissa_scaling(): + # Test different scaling factors + scales = [0.0, 0.1, 1.0] + + for scale in scales: + org_module = torch.nn.Linear(10, 5) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + original_weight = org_module.weight.data.clone() + + lora_down = torch.nn.Linear(10, 2) + lora_up = torch.nn.Linear(2, 5) + + initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2) + + # Check that the weight modification follows the scaling + weight_diff = original_weight - org_module.weight.data + expected_diff = scale * (lora_up.weight.data @ lora_down.weight.data) + + torch.testing.assert_close(weight_diff, expected_diff, rtol=1e-4, atol=1e-4) + + +def test_initialize_pissa_dtype(): + # Test with different data types + dtypes = [torch.float16, torch.float32, torch.float64] + + for dtype in dtypes: + org_module = torch.nn.Linear(10, 5).to(dtype=dtype) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + lora_down = torch.nn.Linear(10, 2) + lora_up = torch.nn.Linear(2, 5) + + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) + + # Verify output dtype matches input + assert org_module.weight.dtype == dtype + + +def 
test_initialize_pissa_svd_properties(): + # Verify SVD decomposition properties + org_module = torch.nn.Linear(20, 10) + lora_down = torch.nn.Linear(20, 3) + lora_up = torch.nn.Linear(3, 10) + + org_module.weight.data = generate_synthetic_weights(org_module.weight) + original_weight = org_module.weight.data.clone() + + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + + # Reconstruct the weight + reconstructed_weight = original_weight - 0.1 * (lora_up.weight.data @ lora_down.weight.data) + + # Check reconstruction is close to original + torch.testing.assert_close(reconstructed_weight, org_module.weight.data, rtol=1e-4, atol=1e-4) + + +def test_initialize_pissa_device_handling(): + # Test different device scenarios + devices = [torch.device("cpu"), torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")] + + for device in devices: + # Create modules on specific device + org_module = torch.nn.Linear(10, 5).to(device) + lora_down = torch.nn.Linear(10, 2).to(device) + lora_up = torch.nn.Linear(2, 5).to(device) + + # Test initialization with explicit device + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device) + + # Verify modules are on the correct device + assert org_module.weight.data.device.type == device.type + assert lora_down.weight.data.device.type == device.type + assert lora_up.weight.data.device.type == device.type + + +def test_initialize_pissa_dtype_preservation(): + # Test dtype preservation and conversion + dtypes = [torch.float16, torch.float32, torch.float64] + + for dtype in dtypes: + org_module = torch.nn.Linear(10, 5).to(dtype=dtype) + lora_down = torch.nn.Linear(10, 2).to(dtype=dtype) + lora_up = torch.nn.Linear(2, 5).to(dtype=dtype) + + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2) + + assert org_module.weight.dtype == dtype + assert lora_down.weight.dtype == dtype + assert lora_up.weight.dtype == dtype + + +def test_initialize_pissa_rank_limits(): + # Test 
rank limits + org_module = torch.nn.Linear(10, 5) + + # Test minimum rank (should work) + lora_down_min = torch.nn.Linear(10, 1) + lora_up_min = torch.nn.Linear(1, 5) + initialize_pissa(org_module, lora_down_min, lora_up_min, scale=0.1, rank=1) + + # Test maximum rank (rank = min(input_dim, output_dim)) + max_rank = min(10, 5) + lora_down_max = torch.nn.Linear(10, max_rank) + lora_up_max = torch.nn.Linear(max_rank, 5) + initialize_pissa(org_module, lora_down_max, lora_up_max, scale=0.1, rank=max_rank) + + +def test_initialize_pissa_numerical_stability(): + # Test with numerically challenging scenarios + scenarios = [ + torch.randn(20, 10) * 1e-10, # Very small values + torch.randn(20, 10) * 1e10, # Very large values + torch.zeros(20, 10), # Zero matrix + ] + + for i, weight_matrix in enumerate(scenarios): + org_module = torch.nn.Linear(10, 20) + org_module.weight.data = weight_matrix + + lora_down = torch.nn.Linear(10, 3) + lora_up = torch.nn.Linear(3, 20) + + try: + initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3) + except Exception as e: + pytest.fail(f"Initialization failed for scenario ({i}): {e}") + + +def test_initialize_pissa_scale_effects(): + # Test different scaling factors + org_module = torch.nn.Linear(10, 5) + original_weight = org_module.weight.data.clone() + + test_scales = [0.0, 0.1, 0.5, 1.0] + + for scale in test_scales: + # Reset module for each test + org_module.weight.data = original_weight.clone() + + lora_down = torch.nn.Linear(10, 2) + lora_up = torch.nn.Linear(2, 5) + + initialize_pissa(org_module, lora_down, lora_up, scale=scale, rank=2) + + # Verify weight modification proportional to scale + weight_diff = original_weight - org_module.weight.data + + # Approximate check of scaling effect + if scale == 0.0: + torch.testing.assert_close(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6) + else: + assert not torch.allclose(weight_diff, torch.zeros_like(weight_diff), rtol=1e-4, atol=1e-6) diff --git 
a/tests/networks/test_lora_flux.py b/tests/networks/test_lora_flux.py new file mode 100644 index 00000000..363b50d1 --- /dev/null +++ b/tests/networks/test_lora_flux.py @@ -0,0 +1,609 @@ +import pytest +import torch +import torch.nn as nn +from networks.lora_flux import LoRAModule, LoRANetwork, create_network +from tests.test_util import generate_synthetic_weights +from unittest.mock import MagicMock + + +def test_basic_linear_module_initialization(): + # Test basic Linear module initialization + org_module = nn.Linear(10, 20) + lora_module = LoRAModule(lora_name="test_linear", org_module=org_module, lora_dim=4) + + # Check basic attributes + assert lora_module.lora_name == "test_linear" + assert lora_module.lora_dim == 4 + + # Check LoRA layers + assert isinstance(lora_module.lora_down, nn.Linear) + assert isinstance(lora_module.lora_up, nn.Linear) + + # Check input and output dimensions + assert lora_module.lora_down.in_features == 10 + assert lora_module.lora_down.out_features == 4 + assert lora_module.lora_up.in_features == 4 + assert lora_module.lora_up.out_features == 20 + + +def test_split_dims_initialization(): + # Test initialization with split_dims + org_module = nn.Linear(10, 15) + lora_module = LoRAModule(lora_name="test_split_dims", org_module=org_module, lora_dim=4, split_dims=[5, 5, 5]) + + # Check split_dims specific attributes + assert lora_module.split_dims == [5, 5, 5] + assert isinstance(lora_module.lora_down, nn.ModuleList) + assert isinstance(lora_module.lora_up, nn.ModuleList) + + # Check number of split modules + assert len(lora_module.lora_down) == 3 + assert len(lora_module.lora_up) == 3 + + # Check dimensions of split modules + for down, up in zip(lora_module.lora_down, lora_module.lora_up): + assert down.in_features == 10 + assert down.out_features == 4 + assert up.in_features == 4 + assert up.out_features in [5, 5, 5] + + +def test_alpha_scaling(): + # Test alpha scaling + org_module = nn.Linear(10, 20) + + # Default alpha (should be 
equal to lora_dim) + lora_module1 = LoRAModule(lora_name="test_alpha1", org_module=org_module, lora_dim=4, alpha=0) + assert lora_module1.scale == 1.0 + + # Custom alpha + lora_module2 = LoRAModule(lora_name="test_alpha2", org_module=org_module, lora_dim=4, alpha=2) + assert lora_module2.scale == 0.5 + + +def test_initialization_methods(): + # Test different initialization methods + org_module = nn.Linear(10, 20) + org_module.weight.data = generate_synthetic_weights(org_module.weight) + + # Default initialization + lora_module1 = LoRAModule(lora_name="test_init_default", org_module=org_module, lora_dim=4) + + assert lora_module1.lora_down.weight.shape == (4, 10) + assert lora_module1.lora_up.weight.shape == (20, 4) + + # URAE initialization + lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4, initialize="urae") + assert hasattr(lora_module2, "_org_lora_up") and lora_module2._org_lora_down is not None + assert hasattr(lora_module2, "_org_lora_down") and lora_module2._org_lora_down is not None + + assert lora_module2.lora_down.weight.shape == (4, 10) + assert lora_module2.lora_up.weight.shape == (20, 4) + + # PISSA initialization + lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4, initialize="pissa") + assert hasattr(lora_module3, "_org_lora_up") and lora_module3._org_lora_down is not None + assert hasattr(lora_module3, "_org_lora_down") and lora_module3._org_lora_down is not None + + assert lora_module3.lora_down.weight.shape == (4, 10) + assert lora_module3.lora_up.weight.shape == (20, 4) + + +@torch.no_grad() +def test_forward_basic_linear(): + # Create a basic linear module + org_module = nn.Linear(10, 20) + org_module.weight.data = torch.testing.make_tensor( + org_module.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0 + ) + + lora_module = LoRAModule(lora_name="test_forward", org_module=org_module, lora_dim=4, alpha=4, multiplier=1.0) + lora_module.apply_to() + 
+ assert isinstance(lora_module.lora_down, nn.Linear) + assert isinstance(lora_module.lora_up, nn.Linear) + + lora_module.lora_down.weight.data = torch.testing.make_tensor( + lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0 + ) + lora_module.lora_up.weight.data = torch.testing.make_tensor( + lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0 + ) + + # Create input + x = torch.ones(5, 10) + + # Perform forward pass + output = lora_module.forward(x) + + # Structural assertions + assert output is not None, "Output should not be None" + assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor" + + # Shape assertions + assert output.shape == (5, 20), "Output shape should match expected dimensions" + + # Type and device assertions + assert output.dtype == torch.float32, "Output should be float32" + assert output.device == x.device, "Output should be on the same device as input" + + +def test_forward_module_dropout(): + # Create a basic linear module + org_module = nn.Linear(10, 20) + + lora_module = LoRAModule( + lora_name="test_module_dropout", + org_module=org_module, + lora_dim=4, + multiplier=1.0, + module_dropout=1.0, # Always drop + ) + + lora_module.apply_to() + + # Create input + x = torch.ones(5, 10) + + # Enable training mode + lora_module.train() + + # Perform forward pass + output = lora_module.forward(x) + + # Check if output is same as original module output + org_output = org_module(x) + torch.testing.assert_close(output, org_output) + + +def test_forward_rank_dropout(): + # Create a basic linear module + org_module = nn.Linear(10, 20) + + lora_module = LoRAModule( + lora_name="test_rank_dropout", + org_module=org_module, + lora_dim=4, + multiplier=1.0, + rank_dropout=0.5, # 50% dropout + ) + + lora_module.apply_to() + + assert isinstance(lora_module.lora_down, nn.Linear) + assert isinstance(lora_module.lora_up, nn.Linear) + + # Make lora weights predictable + 
lora_module.lora_down.weight.data = torch.testing.make_tensor( + lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0 + ) + lora_module.lora_up.weight.data = torch.testing.make_tensor( + lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0 + ) + + # Create input + x = torch.ones(5, 10) + + # Enable training mode + lora_module.train() + + # Perform multiple forward passes to show dropout effect + outputs = [lora_module.forward(x) for _ in range(10)] + + # Check that outputs are not all identical due to rank dropout + differences = [ + torch.all(torch.eq(outputs[i], outputs[j])).item() for i in range(len(outputs)) for j in range(i + 1, len(outputs)) + ] + assert not all(differences) + + +def test_forward_split_dims(): + # Create a basic linear module with split dimensions + org_module = nn.Linear(10, 15) + + lora_module = LoRAModule(lora_name="test_split_dims", org_module=org_module, lora_dim=4, multiplier=1.0, split_dims=[5, 5, 5]) + + lora_module.apply_to() + + assert isinstance(lora_module.lora_down, nn.ModuleList) + assert isinstance(lora_module.lora_up, nn.ModuleList) + + # Make lora weights predictable + for down in lora_module.lora_down: + assert isinstance(down, nn.Linear) + down.weight.data = torch.testing.make_tensor(down.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0) + for up in lora_module.lora_up: + assert isinstance(up, nn.Linear) + up.weight.data = torch.testing.make_tensor(up.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0) + + # Create input + x = torch.ones(5, 10) + + # Perform forward pass + output = lora_module.forward(x) + + # Check output dimensions + assert output.shape == (5, 15) + + +def test_forward_dropout(): + # Create a basic linear module + org_module = nn.Linear(10, 20) + + lora_module = LoRAModule( + lora_name="test_dropout", + org_module=org_module, + lora_dim=4, + multiplier=1.0, + dropout=0.5, # 50% dropout + ) + + 
lora_module.apply_to() + + assert isinstance(lora_module.lora_down, nn.Linear) + assert isinstance(lora_module.lora_up, nn.Linear) + + # Make lora weights predictable + lora_module.lora_down.weight.data = torch.testing.make_tensor( + lora_module.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0 + ) + lora_module.lora_up.weight.data = torch.testing.make_tensor( + lora_module.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0 + ) + + # Create input + x = torch.ones(5, 10) + + # Enable training mode + lora_module.train() + + # Perform multiple forward passes to show dropout effect + outputs = [lora_module.forward(x) for _ in range(10)] + + # Check that outputs are not all identical due to dropout + differences = [ + torch.all(torch.eq(outputs[i], outputs[j])).item() for i in range(len(outputs)) for j in range(i + 1, len(outputs)) + ] + assert not all(differences) + + +def test_create_network_default_parameters(mock_text_encoder, mock_flux): + # Mock dependencies + mock_ae = MagicMock() + mock_text_encoders = [mock_text_encoder, mock_text_encoder] + + # Call the function with minimal parameters + network = create_network( + multiplier=1.0, network_dim=None, network_alpha=None, ae=mock_ae, text_encoders=mock_text_encoders, flux=mock_flux + ) + + # Assertions + assert network is not None + assert network.multiplier == 1.0 + assert network.lora_dim == 4 # default network_dim + assert network.alpha == 1.0 # default network_alpha + + +@pytest.fixture +def mock_text_encoder(): + class CLIPAttention(nn.Module): + def __init__(self): + super().__init__() + # Add some dummy layers to simulate a CLIPAttention + self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)]) + + class MockTextEncoder(nn.Module): + def __init__(self): + super().__init__() + # Add some dummy layers to simulate a CLIPTextModel + self.attns = torch.nn.ModuleList([CLIPAttention() for _ in range(3)]) + + return MockTextEncoder() + + 
+@pytest.fixture +def mock_flux(): + class DoubleStreamBlock(nn.Module): + def __init__(self): + super().__init__() + # Add some dummy layers to simulate a DoubleStreamBlock + self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)]) + + class SingleStreamBlock(nn.Module): + def __init__(self): + super().__init__() + # Add some dummy layers to simulate a SingleStreamBlock + self.layers = torch.nn.ModuleList([torch.nn.Linear(10, 15) for _ in range(3)]) + + class MockFlux(torch.nn.Module): + def __init__(self): + super().__init__() + # Add some dummy layers to simulate a Flux + self.double_blocks = torch.nn.ModuleList([DoubleStreamBlock() for _ in range(3)]) + self.single_blocks = torch.nn.ModuleList([SingleStreamBlock() for _ in range(3)]) + + return MockFlux() + + +def test_create_network_custom_parameters(mock_text_encoder, mock_flux): + # Mock dependencies + mock_ae = MagicMock() + mock_text_encoders = [mock_text_encoder, mock_text_encoder] + + # Prepare custom parameters + custom_params = { + "conv_dim": 8, + "conv_alpha": 0.5, + "img_attn_dim": 16, + "txt_attn_dim": 16, + "neuron_dropout": 0.1, + "rank_dropout": 0.2, + "module_dropout": 0.3, + "train_blocks": "double", + "split_qkv": "True", + "train_t5xxl": "True", + "in_dims": "[64, 32, 16, 8, 4]", + "verbose": "True", + } + + # Call the function with custom parameters + network = create_network( + multiplier=1.5, + network_dim=8, + network_alpha=2.0, + ae=mock_ae, + text_encoders=mock_text_encoders, + flux=mock_flux, + **custom_params, + ) + + # Assertions + assert network is not None + assert network.multiplier == 1.5 + assert network.lora_dim == 8 + assert network.alpha == 2.0 + assert network.conv_lora_dim == 8 + assert network.conv_alpha == 0.5 + assert network.train_blocks == "double" + assert network.split_qkv is True + assert network.train_t5xxl is True + + +def test_create_network_block_indices(mock_text_encoder, mock_flux): + # Mock dependencies + mock_ae = MagicMock() + 
mock_text_encoders = [mock_text_encoder, mock_text_encoder] + + # Test block indices parsing + network = create_network( + multiplier=1.0, + network_dim=4, + network_alpha=1.0, + ae=mock_ae, + text_encoders=mock_text_encoders, + flux=mock_flux, + neuron_dropout=None, + **{"train_double_block_indices": "0-2,4", "train_single_block_indices": "1,3"}, + ) + + # Assertions would depend on the exact implementation of parsing + assert network.train_double_block_indices is not None + assert network.train_single_block_indices is not None + + double_block_indices = [ + True, + True, + True, + False, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ] + single_block_indices = [ + False, + True, + False, + True, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ] + + assert network.train_double_block_indices == double_block_indices + assert network.train_single_block_indices == single_block_indices + + +def test_create_network_loraplus_ratios(mock_text_encoder, mock_flux): + # Mock dependencies + mock_ae = MagicMock() + mock_text_encoders = [mock_text_encoder, mock_text_encoder] + + # Test LoRA+ ratios + network = create_network( + multiplier=1.0, + network_dim=4, + network_alpha=1.0, + ae=mock_ae, + text_encoders=mock_text_encoders, + flux=mock_flux, + neuron_dropout=None, + **{"loraplus_lr_ratio": 2.0, "loraplus_unet_lr_ratio": 1.5, "loraplus_text_encoder_lr_ratio": 1.0}, + ) + + # Verify LoRA+ ratios were set correctly + assert network.loraplus_lr_ratio == 2.0 + assert network.loraplus_unet_lr_ratio == 1.5 + assert network.loraplus_text_encoder_lr_ratio == 1.0 + + +def test_create_network_loraplus_default_ratio(mock_text_encoder, 
mock_flux): + # Mock dependencies + mock_ae = MagicMock() + mock_text_encoders = [mock_text_encoder, mock_text_encoder] + + # Test when only global LoRA+ ratio is provided + network = create_network( + multiplier=1.0, + network_dim=4, + network_alpha=1.0, + ae=mock_ae, + text_encoders=mock_text_encoders, + flux=mock_flux, + neuron_dropout=None, + **{"loraplus_lr_ratio": 2.0}, + ) + + # Verify only global ratio is set + assert network.loraplus_lr_ratio == 2.0 + assert network.loraplus_unet_lr_ratio is None + assert network.loraplus_text_encoder_lr_ratio is None + + +def test_create_network_invalid_inputs(mock_text_encoder, mock_flux): + # Mock dependencies + mock_ae = MagicMock() + mock_text_encoders = [mock_text_encoder, mock_text_encoder] + mock_flux = mock_flux + + # Test invalid train_blocks + with pytest.raises(AssertionError): + create_network( + multiplier=1.0, + network_dim=4, + network_alpha=1.0, + ae=mock_ae, + text_encoders=mock_text_encoders, + flux=mock_flux, + neuron_dropout=None, + **{"train_blocks": "invalid"}, + ) + + # Test invalid in_dims + with pytest.raises(AssertionError): + create_network( + multiplier=1.0, + network_dim=4, + network_alpha=1.0, + ae=mock_ae, + text_encoders=mock_text_encoders, + flux=mock_flux, + neuron_dropout=None, + **{"in_dims": "[1,2,3]"}, # Should be 5 dimensions + ) + + +def test_lora_network_initialization(mock_text_encoder, mock_flux): + # Test basic initialization with default parameters + lora_network = LoRANetwork(text_encoders=[mock_text_encoder, mock_text_encoder], unet=mock_flux) + + # Check basic attributes + assert lora_network.multiplier == 1.0 + assert lora_network.lora_dim == 4 + assert lora_network.alpha == 1 + assert lora_network.train_blocks == "all" + + # Check LoRA modules are created + assert len(lora_network.text_encoder_loras) > 0 + assert len(lora_network.unet_loras) > 0 + + +def test_lora_network_initialization_with_custom_params(mock_text_encoder, mock_flux): + # Test initialization with custom 
parameters + lora_network = LoRANetwork( + text_encoders=[mock_text_encoder], + unet=mock_flux, + multiplier=0.5, + lora_dim=8, + alpha=2.0, + dropout=0.1, + rank_dropout=0.05, + module_dropout=0.02, + train_blocks="single", + split_qkv=True, + ) + + # Verify custom parameters are set correctly + assert lora_network.multiplier == 0.5 + assert lora_network.lora_dim == 8 + assert lora_network.alpha == 2.0 + assert lora_network.dropout == 0.1 + assert lora_network.rank_dropout == 0.05 + assert lora_network.module_dropout == 0.02 + assert lora_network.train_blocks == "single" + assert lora_network.split_qkv is True + + +def test_lora_network_initialization_with_custom_modules_dim(mock_text_encoder, mock_flux): + # Test initialization with custom module dimensions + modules_dim = {"lora_te1_attns_0_layers_0": 16, "lora_unet_double_blocks_0_layers_0": 8} + modules_alpha = {"lora_te1_attns_0_layers_0": 2, "lora_unet_double_blocks_0_layers_0": 1} + + lora_network = LoRANetwork( + text_encoders=[mock_text_encoder, mock_text_encoder], unet=mock_flux, modules_dim=modules_dim, modules_alpha=modules_alpha + ) + + # [LoRAModule( + # (lora_down): Linear(in_features=10, out_features=8, bias=False) + # (lora_up): Linear(in_features=8, out_features=15, bias=False) + # (org_module): Linear(in_features=10, out_features=15, bias=True) + # )] + # [LoRAModule( + # (lora_down): Linear(in_features=10, out_features=16, bias=False) + # (lora_up): Linear(in_features=16, out_features=15, bias=False) + # (org_module): Linear(in_features=10, out_features=15, bias=True) + # )] + + assert isinstance(lora_network.unet_loras[0].lora_down, torch.nn.Linear) + assert isinstance(lora_network.unet_loras[0].lora_up, torch.nn.Linear) + assert lora_network.unet_loras[0].lora_down.weight.data.shape[0] == modules_dim["lora_unet_double_blocks_0_layers_0"] + assert lora_network.unet_loras[0].lora_up.weight.data.shape[1] == modules_dim["lora_unet_double_blocks_0_layers_0"] + assert 
lora_network.unet_loras[0].alpha == modules_alpha["lora_unet_double_blocks_0_layers_0"] + + assert isinstance(lora_network.text_encoder_loras[0].lora_down, torch.nn.Linear) + assert isinstance(lora_network.text_encoder_loras[0].lora_up, torch.nn.Linear) + assert lora_network.text_encoder_loras[0].lora_down.weight.data.shape[0] == modules_dim["lora_te1_attns_0_layers_0"] + assert lora_network.text_encoder_loras[0].lora_up.weight.data.shape[1] == modules_dim["lora_te1_attns_0_layers_0"] + assert lora_network.text_encoder_loras[0].alpha == modules_alpha["lora_te1_attns_0_layers_0"]