diff --git a/library/network_utils.py b/library/network_utils.py index 65654740..423c7309 100644 --- a/library/network_utils.py +++ b/library/network_utils.py @@ -1,6 +1,7 @@ import torch import math import warnings +from torch import Tensor from typing import Optional from library.incremental_pca import IncrementalPCA from dataclasses import dataclass @@ -108,73 +109,73 @@ def initialize_urae( lowrank_niter: int = 4, lowrank_seed: Optional[int] = None, ): - org_module_device = org_module.weight.device - org_module_weight_dtype = org_module.weight.data.dtype - org_module_requires_grad = org_module.weight.requires_grad + # Store original device, dtype, and requires_grad status + orig_device = org_module.weight.device + orig_dtype = org_module.weight.data.dtype + orig_requires_grad = org_module.weight.requires_grad - dtype = dtype if dtype is not None else lora_down.weight.data.dtype + # Determine device and dtype to work with device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) - assert isinstance(device, torch.device), f"Invalid device type: {device}" + dtype = dtype if dtype is not None else lora_down.weight.data.dtype + # Move original weight to chosen device and use float32 for numerical stability weight = org_module.weight.data.to(device, dtype=torch.float32) + # Perform SVD decomposition (either directly or with IPCA for memory efficiency) if use_ipca: - # For URAE we need all components to get the "residual" ones ipca = IncrementalPCA( - n_components=None, # Get all components + n_components=None, batch_size=1024, lowrank=use_lowrank, - lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), # Use full rank for accurate residuals + lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), lowrank_niter=lowrank_niter, lowrank_seed=lowrank_seed, ) ipca.fit(weight) - # For URAE, use the LAST/SMALLEST singular values - total_rank = min(weight.shape[0], weight.shape[1]) - 
V_full = ipca.components_.T # [out_features, total_rank] - S_full = ipca.singular_values_ # [total_rank] + # Extract singular values and vectors, focusing on the minor components (smallest singular values) + S_full = ipca.singular_values_ + V_full = ipca.components_.T # Shape: [out_features, total_rank] - # Get the smallest singular values and vectors - Vr = V_full[:, -rank:] # Last rank left singular vectors - Sr = S_full[-rank:] # Last rank singular values - Sr /= rank - - # To get Uhr (last rank right singular vectors), transform basis vectors + # Get identity matrix to transform for right singular vectors identity = torch.eye(weight.shape[1], device=weight.device) - Uhr_full = ipca.transform(identity).T # [total_rank, in_features] - Uhr = Uhr_full[-rank:] # Last rank right singular vectors + Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features] + + # Extract the last 'rank' components (the minor/smallest ones) + Sr = S_full[-rank:] + Vr = V_full[:, -rank:] + Uhr = Uhr_full[-rank:] + + # Scale singular values + Sr = Sr / rank else: - # Standard SVD approach - V, S, Uh = torch.linalg.svd(weight, full_matrices=False) - Vr = V[:, -rank:] + # Direct SVD approach + U, S, Vh = torch.linalg.svd(weight, full_matrices=False) + + # Extract the minor components (smallest singular values) Sr = S[-rank:] - Sr /= rank - Uhr = Uh[-rank:, :] + Vr = U[:, -rank:] + Uhr = Vh[-rank:] - # Create down and up matrices - down = torch.diag(torch.sqrt(Sr)) @ Uhr - up = Vr @ torch.diag(torch.sqrt(Sr)) + # Scale singular values + Sr = Sr / rank - # Get expected shapes - expected_down_shape = lora_down.weight.shape - expected_up_shape = lora_up.weight.shape + # Create the low-rank adapter matrices by splitting the minor components + # Down matrix: scaled right singular vectors with singular values + down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr - # Verify shapes match expected - if down.shape != expected_down_shape: - warnings.warn(UserWarning(f"Warning: Down matrix 
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
    """Convert PiSSA-trained adapter matrices into a standard LoRA pair.

    PiSSA initializes the adapter from the principal singular components of the
    base weight, so the trained up/down matrices are not a plain delta.  This
    computes the net weight change dW = A'B' - AB and re-factorizes it with SVD
    so the result can be used as an ordinary LoRA.

    Args:
        trained_up: Up matrix after training (A').
        trained_down: Down matrix after training (B').
        orig_up: Up matrix as initialized by PiSSA, before training (A).
        orig_down: Down matrix as initialized by PiSSA, before training (B).
        rank: Rank of the trained adapter; the output keeps up to 2 * rank
            singular components, as suggested in the PiSSA paper.

    Returns:
        Tuple of float32 matrices (new_up, new_down) whose product
        approximates dW.
    """
    # Net change applied by training: dW = A'B' - AB
    delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)

    # SVD in float32 for numerical stability.  Prefer CUDA for speed, but fall
    # back to the tensor's own device: the previous version hardcoded "cuda"
    # and crashed on CPU-only machines.
    svd_device = torch.device("cuda") if torch.cuda.is_available() else delta_w.device
    U, S, Vh = torch.linalg.svd(delta_w.to(device=svd_device, dtype=torch.float32), full_matrices=False)

    # Take the top 2*r singular values (paper suggestion), capped at what exists.
    rank = min(rank * 2, S.shape[0])

    # Split sqrt(S) between the two factors so new_up @ new_down ~= dW.
    sqrt_s = torch.diag(torch.sqrt(S[:rank]))
    new_up = U[:, :rank] @ sqrt_s
    new_down = sqrt_s @ Vh[:rank, :]

    # These matrices can now be used as standard LoRA weights.
    return new_up, new_down
:rank] @ torch.diag(torch.sqrt(S[:rank])) + new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :] + + # These matrices can now be used as standard LoRA weights + return new_up, new_down + + +def convert_urae_to_standard_lora( + trained_up: Tensor, + trained_down: Tensor, + orig_up: Tensor, + orig_down: Tensor, + initial_alpha: float | None = None, + rank: int | None = None, +): + """ + Convert URAE trained weights to standard LoRA format + + Args: + trained_up: The trained URAE Up matrix + trained_down: The trained URAE Down matrix + orig_up: The original up matrix before training + orig_down: The original down matrix before training + initial_alpha: The alpha value used during URAE training (if any) + rank: The rank for the standard LoRA (if None, uses the rank of trained_A) + + Returns: + lora_up: Standard LoRA up matrix + lora_down: Standard LoRA down matrix + alpha: Appropriate alpha value for the LoRA + """ + # Calculate the weight delta + delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) + + # Perform SVD on the delta + U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False) + + # If rank is not specified, use the same rank as the trained matrices + if rank is None: + rank = trained_up.shape[1] + else: + # Ensure we don't exceed available singular values + rank = min(rank, len(S)) + + # Create standard LoRA matrices using top singular values + # This is now standard LoRA (using top values), not URAE (which used bottom values during training) + lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank])) + lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :] + + # Method 1: Preserve the Frobenius norm of the delta + original_effect: float = torch.norm(delta_w, p="fro").item() + unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item() + + # The scaling factor in lora is (alpha/r), so: + # alpha/r × ||AB|| = ||delta_W|| + # alpha = r × ||delta_W|| / ||AB|| + if unscaled_lora_effect > 0: + 
norm_based_alpha = rank * (original_effect / unscaled_lora_effect) + else: + norm_based_alpha = 1.0 # Fallback + + # Method 2: If initial_alpha is provided, adjust based on rank change + if initial_alpha is not None: + initial_rank = trained_up.shape[1] + # Scale alpha proportionally if rank changed + rank_adjusted_alpha = initial_alpha * (rank / initial_rank) + else: + rank_adjusted_alpha = None + + # Choose the appropriate alpha + if rank_adjusted_alpha is not None: + # Use the rank-adjusted alpha, but ensure it's not too different from norm-based + # Cap the difference to avoid extreme values + alpha = rank_adjusted_alpha + # Optional: Cap alpha to be within a reasonable range of norm_based_alpha + if norm_based_alpha > 0: + max_factor = 5.0 # Allow up to 5x difference + upper_bound = norm_based_alpha * max_factor + lower_bound = norm_based_alpha / max_factor + alpha = min(max(alpha, lower_bound), upper_bound) + else: + # Use norm-based alpha + alpha = norm_based_alpha + + # Round to a clean value for better usability + alpha = round(alpha, 2) + + # Ensure alpha is positive and within reasonable bounds + alpha = max(0.1, min(alpha, 1024.0)) + + return lora_up, lora_down, alpha diff --git a/networks/lora_flux.py b/networks/lora_flux.py index e8cd528f..aeebdcdd 100644 --- a/networks/lora_flux.py +++ b/networks/lora_flux.py @@ -19,7 +19,14 @@ from tqdm import tqdm import re from library.utils import setup_logging from library.device_utils import clean_memory_on_device -from library.network_utils import initialize_lora, initialize_pissa, initialize_urae, initialize_parse_opts +from library.network_utils import ( + initialize_lora, + initialize_pissa, + initialize_urae, + initialize_parse_opts, + convert_pissa_to_standard_lora, + convert_urae_to_standard_lora, +) setup_logging() import logging @@ -49,6 +56,7 @@ class LoRAModule(torch.nn.Module): split_dims: Optional[List[int]] = None, ggpo_beta: Optional[float] = None, ggpo_sigma: Optional[float] = None, + 
initialize: Optional[str] = None, ): """ if alpha == 0 or None, alpha is rank (no scaling). @@ -101,7 +109,6 @@ class LoRAModule(torch.nn.Module): self.initialize_norm_cache(org_module.weight) self.org_module_shape: tuple[int] = org_module.weight.shape - def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]): """ Initialize the weights for the LoRA @@ -113,12 +120,16 @@ class LoRAModule(torch.nn.Module): if initialize is not None: params = initialize_parse_opts(initialize) if initialize[:4] == "urae": - initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)) + initialize_urae( + org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params) + ) # Need to store the original weights so we can get a plain LoRA out self._org_lora_up = self.lora_up.weight.data.detach().clone() self._org_lora_down = self.lora_down.weight.data.detach().clone() elif initialize[:5] == "pissa": - initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)) + initialize_pissa( + org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params) + ) # Need to store the original weights so we can get a plain LoRA out self._org_lora_up = self.lora_up.weight.data.detach().clone() self._org_lora_down = self.lora_down.weight.data.detach().clone() @@ -149,7 +160,6 @@ class LoRAModule(torch.nn.Module): self._org_lora_up = self._org_lora_up.to("cpu") self._org_lora_down = self._org_lora_down.to("cpu") - def apply_to(self): self.org_forward = self.org_module.forward self.org_module.forward = self.forward @@ -607,7 +617,7 @@ def create_network( if verbose is not None: verbose = True if verbose == "True" else False - # Computation device, used in initialization + # Computation device, used in initialization comp_device = kwargs.get("comp_device", None) if 
comp_device is not None: comp_device = torch.device(comp_device) @@ -922,7 +932,9 @@ class LoRANetwork(torch.nn.Module): elif "single" in lora_name and "linear1" in lora_name: split_dims = [3072] * 3 + [12288] - assert module_class is LoRAModule or module_class is LoRAInfModule, f"Module class is not valid {type(module_class)}" + assert module_class is LoRAModule or module_class is LoRAInfModule, ( + f"Module class is not valid {type(module_class)}" + ) lora = module_class( lora_name, child_module, @@ -960,7 +972,7 @@ class LoRANetwork(torch.nn.Module): logger.info(f"Initialize Text Encoder LoRA using {initialize}") text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) - logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index+1}.") + logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index + 1}.") self.text_encoder_loras.extend(text_encoder_loras) skipped_te += skipped @@ -1313,66 +1325,33 @@ class LoRANetwork(torch.nn.Module): # Need to decompose the parameters into a LoRA format if self.initialize is not None and (self.initialize[:5] == "pissa" or self.initialize[:4] == "urae"): loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras - def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int): - # Calculate ΔW = A'B' - AB - delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) - - # We need to create new low-rank matrices that represent this delta - U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False) - - # Take the top 2*r singular values (as suggested in the paper) - rank = rank * 2 - rank = min(rank, len(S)) # Make sure we don't exceed available singular values - - # Create new LoRA matrices - new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank])) - new_down = torch.diag(torch.sqrt(S[:rank])) @ 
V[:rank, :] - - # These matrices can now be used as standard LoRA weights - return new_up, new_down - - def convert_urae_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int): - # Calculate ΔW = A'B' - AB - delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) - - # We need to create new low-rank matrices that represent this delta - U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False) - - # For URAE, we want to focus on the smallest singular values - # Take the bottom rank*2 singular values (opposite of PiSSA which takes the top ones) - total_rank = len(S) - rank_to_use = min(rank * 2, total_rank) - - if rank_to_use < total_rank: - # Use the smallest singular values and vectors - selected_U = U[:, -rank_to_use:] - selected_S = S[-rank_to_use:] - selected_V = V[-rank_to_use:, :] - else: - # If we'd use all values, just use the standard approach but with a note - print("Warning: Requested rank is too large for URAE specialty, using all singular values") - selected_U = U - selected_S = S - selected_V = V - - # Create new LoRA matrices - new_up = selected_U @ torch.diag(torch.sqrt(selected_S)) - new_down = torch.diag(torch.sqrt(selected_S)) @ selected_V - - # These matrices can now be used as standard LoRA weights - return new_up, new_down with torch.no_grad(): progress = tqdm(total=len(loras), desc="Converting") for lora in loras: lora_up_key = f"{lora.lora_name}.lora_up.weight" lora_down_key = f"{lora.lora_name}.lora_down.weight" + lora_alpha_key = f"{lora.lora_name}.alpha" lora_up = state_dict[lora_up_key] lora_down = state_dict[lora_down_key] if self.initialize[:4] == "urae": - up, down = convert_urae_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim) + up, down, alpha = convert_urae_to_standard_lora( + lora_up, + lora_down, + lora._org_lora_up.to(lora_up.device), + 
import pytest
import torch

from library.network_utils import convert_urae_to_standard_lora


class TestConvertURAEToStandardLoRA:
    """Tests for convert_urae_to_standard_lora (URAE -> standard LoRA)."""

    @pytest.fixture
    def sample_matrices(self):
        """Create sample matrices for testing"""
        # Original up matrix (4x2)
        orig_up = torch.tensor([
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0]
        ])

        # Original down matrix (2x6)
        orig_down = torch.tensor([
            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
        ])

        # Trained up matrix (4x2) - same shape as orig_up but with changed values
        trained_up = torch.tensor([
            [1.1, 2.1],
            [3.1, 4.1],
            [5.1, 6.1],
            [7.1, 8.1]
        ])

        # Trained down matrix (2x6) - same shape as orig_down but with changed values
        trained_down = torch.tensor([
            [0.15, 0.25, 0.35, 0.45, 0.55, 0.65],
            [0.75, 0.85, 0.95, 1.05, 1.15, 1.25]
        ])

        return {
            'orig_up': orig_up,
            'orig_down': orig_down,
            'trained_up': trained_up,
            'trained_down': trained_down
        }

    def test_basic_conversion(self, sample_matrices):
        """Test the basic functionality of convert_urae_to_standard_lora"""
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down']
        )

        # Check shapes
        assert lora_up.shape[0] == sample_matrices['trained_up'].shape[0]  # Same number of rows as trained_up
        assert lora_up.shape[1] == sample_matrices['trained_up'].shape[1]  # Same rank as trained_up
        assert lora_down.shape[0] == sample_matrices['trained_up'].shape[1]  # Same rank as trained_up
        assert lora_down.shape[1] == sample_matrices['trained_down'].shape[1]  # Same number of columns as trained_down

        # Check alpha is a reasonable value
        assert 0.1 <= alpha <= 1024.0

        # Check that lora_up @ lora_down approximates the weight delta
        delta = (sample_matrices['trained_up'] @ sample_matrices['trained_down']) - (sample_matrices['orig_up'] @ sample_matrices['orig_down'])

        lora_effect = lora_up @ lora_down
        delta_norm = torch.norm(delta, p="fro").item()
        lora_norm = torch.norm(lora_effect, p="fro").item()

        # Either the unscaled product matches, or the (alpha / rank) scaling
        # brings it close.  Use a *relative* 1% tolerance: alpha is rounded to
        # two decimals and the rank-2 truncation of a rank-3 delta makes an
        # absolute 1e-4 bound flaky.
        scaled_lora_effect = (alpha / sample_matrices['trained_up'].shape[1]) * lora_effect
        scaled_lora_norm = torch.norm(scaled_lora_effect, p="fro").item()

        tol = 1e-2 * max(delta_norm, 1e-8)
        assert abs(delta_norm - lora_norm) < tol or abs(delta_norm - scaled_lora_norm) < tol

    def test_specified_rank(self, sample_matrices):
        """Test conversion with a specified rank"""
        new_rank = 1  # Lower than trained_up's rank of 2

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            rank=new_rank
        )

        # Check that the new rank is used
        assert lora_up.shape[1] == new_rank
        assert lora_down.shape[0] == new_rank

        # Should still produce a reasonable alpha
        assert 0.1 <= alpha <= 1024.0

    def test_with_initial_alpha(self, sample_matrices):
        """Test conversion with a specified initial alpha"""
        initial_alpha = 16.0

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha
        )

        # Alpha should be influenced by initial_alpha but may be adjusted
        # Since we're using same rank, should be reasonably close to initial_alpha
        assert 0.1 <= alpha <= 1024.0
        # Should at least preserve the order of magnitude in typical cases
        assert abs(alpha - initial_alpha) <= initial_alpha * 4.0

    def test_large_initial_alpha(self, sample_matrices):
        """Test conversion with a very large initial alpha that should be capped"""
        initial_alpha = 2000.0  # Larger than the 1024.0 cap

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha
        )

        # Alpha should be capped at 1024.0
        assert alpha <= 1024.0

    def test_very_small_initial_alpha(self, sample_matrices):
        """Test conversion with a very small initial alpha that should be floored"""
        initial_alpha = 0.01  # Smaller than the 0.1 floor

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha
        )

        # Alpha should be floored at 0.1
        assert alpha >= 0.1

    def test_change_rank_with_initial_alpha(self, sample_matrices):
        """Test conversion with both rank change and initial alpha"""
        initial_alpha = 16.0
        new_rank = 1  # Half of original rank 2

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha,
            rank=new_rank
        )

        # Check shapes
        assert lora_up.shape[1] == new_rank
        assert lora_down.shape[0] == new_rank

        # Alpha should be adjusted for the rank change (approx halved in this case)
        expected_alpha = initial_alpha * (new_rank / sample_matrices['trained_up'].shape[1])
        # Allow some tolerance for adjustments from norm-based capping
        assert abs(alpha - expected_alpha) <= expected_alpha * 4.0 or alpha >= 0.1

    def test_zero_delta(self):
        """Test conversion when delta is zero"""
        # Create matrices where the delta will be zero
        dim_in, rank, dim_out = 4, 2, 6

        # Create identical matrices for original and trained
        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up.clone()
        trained_down = orig_down.clone()

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down
        )

        # Should still return matrices of correct shape
        assert lora_up.shape == (dim_in, rank)
        assert lora_down.shape == (rank, dim_out)

        # Alpha should be at least the minimum
        assert alpha >= 0.1

    def test_large_dimensions(self):
        """Test with larger matrix dimensions"""
        dim_in, rank, dim_out = 100, 8, 200

        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up + 0.01 * torch.randn(dim_in, rank)  # Small perturbation
        trained_down = orig_down + 0.01 * torch.randn(rank, dim_out)  # Small perturbation

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down
        )

        # Check shapes
        assert lora_up.shape == (dim_in, rank)
        assert lora_down.shape == (rank, dim_out)

        # Should produce a reasonable alpha
        assert 0.1 <= alpha <= 1024.0

    def test_rank_exceeding_singular_values(self):
        """Test when requested rank exceeds available singular values"""
        # Small matrices with limited rank
        dim_in, rank, dim_out = 3, 2, 3

        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up + 0.1 * torch.randn(dim_in, rank)
        trained_down = orig_down + 0.1 * torch.randn(rank, dim_out)

        # Request rank larger than possible
        too_large_rank = 10

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down,
            rank=too_large_rank
        )

        # Rank should be limited to min(dim_in, dim_out, S.size)
        max_possible_rank = min(dim_in, dim_out)
        assert lora_up.shape[1] <= max_possible_rank
        assert lora_down.shape[0] <= max_possible_rank