import torch import math import warnings from torch import Tensor from typing import Optional from dataclasses import dataclass @dataclass class InitializeParams: """Parameters for initialization methods (PiSSA, URAE)""" use_lowrank: bool = False # lowrank_q: Optional[int] = None lowrank_niter: int = 4 def initialize_parse_opts(key: str) -> InitializeParams: """ Parse initialization parameters from a string key. Format examples: - "pissa" -> Default PiSSA with lowrank=True, niter=4 - "pissa_niter_4" -> PiSSA with niter=4 - "pissa_lowrank_false" -> PiSSA without lowrank Args: key: String key to parse Returns: InitializeParams object with parsed parameters """ parts = key.lower().split("_") # Extract the method (first part) method = parts[0] if method not in ["pissa", "urae"]: raise ValueError(f"Unknown initialization method: {method}") # Start with default parameters params = InitializeParams() # Parse the remaining parts i = 1 while i < len(parts): if parts[i] == "lowrank": if i + 1 < len(parts) and parts[i + 1] in ["true", "false"]: params.use_lowrank = parts[i + 1] == "true" i += 2 else: params.use_lowrank = True i += 1 elif parts[i] == "niter": if i + 1 < len(parts) and parts[i + 1].isdigit(): params.lowrank_niter = int(parts[i + 1]) params.use_lowrank = True i += 2 else: i += 1 else: # Skip unknown parameter i += 1 return params def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module): torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5)) torch.nn.init.zeros_(lora_up.weight) # URAE: Ultra-Resolution Adaptation with Ease def initialize_urae( org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, **kwargs, ): # Store original device, dtype, and requires_grad status orig_device = org_module.weight.device orig_dtype = org_module.weight.data.dtype orig_requires_grad = org_module.weight.requires_grad # Determine device and dtype to work with device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) dtype = dtype if dtype is not None else lora_down.weight.data.dtype # Move original weight to chosen device and use float32 for numerical stability weight = org_module.weight.data.to(device, dtype=torch.float32) with torch.no_grad(): # Direct SVD approach U, S, Vh = torch.linalg.svd(weight, full_matrices=False) # Extract the minor components (smallest singular values) Sr = S[-rank:] Vr = U[:, -rank:] Uhr = Vh[-rank:] # Scale singular values Sr = Sr / rank # Create the low-rank adapter matrices by splitting the minor components # Down matrix: scaled right singular vectors with singular values down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr # Up matrix: scaled left singular vectors with singular values up_matrix = Vr @ torch.diag(torch.sqrt(Sr)) # Assign to LoRA modules lora_down.weight.data = down_matrix.to(device=device, dtype=dtype) lora_up.weight.data = up_matrix.to(device=device, dtype=dtype) # Update the original weight by removing the minor components # This is equivalent to keeping only the major components modified_weight = weight - scale * (up_matrix @ down_matrix) org_module.weight.data = modified_weight.to(device=orig_device, dtype=orig_dtype) org_module.weight.requires_grad = orig_requires_grad # PiSSA: Principal Singular Values and Singular Vectors Adaptation def initialize_pissa( org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, use_lowrank: bool = False, # lowrank_q: Optional[int] = None, lowrank_niter: int = 4, ): org_module_device = org_module.weight.device org_module_weight_dtype = org_module.weight.data.dtype org_module_requires_grad = org_module.weight.requires_grad dtype = dtype if dtype is not None else lora_down.weight.data.dtype device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) assert isinstance(device, torch.device), f"Invalid device type: {device}" weight = org_module.weight.data.clone().to(device, dtype=torch.float32) with torch.no_grad(): if use_lowrank: # q_value = lowrank_q if lowrank_q is not None else 2 * rank Vr, Sr, Ur = torch.svd_lowrank(weight.data, q=rank, niter=lowrank_niter) Sr /= rank Uhr = Ur.t() else: # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False) Vr = V[:, :rank] Sr = S[:rank] Sr /= rank Uhr = Uh[:rank] down = torch.diag(torch.sqrt(Sr)) @ Uhr up = Vr @ torch.diag(torch.sqrt(Sr)) # Get expected shapes expected_down_shape = lora_down.weight.shape expected_up_shape = lora_up.weight.shape # Verify shapes match expected or reshape appropriately if down.shape != expected_down_shape: warnings.warn(UserWarning(f"Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")) if up.shape != expected_up_shape: warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}")) lora_up.weight.data = up.to(lora_up.weight.data.device, dtype=lora_up.weight.dtype) lora_down.weight.data = down.to(lora_down.weight.data.device, dtype=lora_down.weight.dtype) weight = weight.data - scale * (up @ down) org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype) org_module.weight.requires_grad = org_module_requires_grad def convert_pissa_to_standard_lora( trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int ): with torch.no_grad(): # Calculate ΔW = A'B' - AB delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) # We need to create new low-rank matrices that represent this delta U, S, V = torch.linalg.svd(delta_w.to(trained_up.device, dtype=torch.float32), full_matrices=False) # Take the top 2*r singular values (as suggested in the paper) rank = rank * 2 rank = min(rank, len(S)) # Make sure we don't exceed available singular values # Create new LoRA matrices new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank])) new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :] # These matrices can now be used as standard LoRA weights return new_up, new_down def convert_urae_to_standard_lora( trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, initial_alpha: float | None = None, rank: int | None = None, ): """ Convert URAE trained weights to standard LoRA format Args: trained_up: The trained URAE Up matrix trained_down: The trained URAE Down matrix orig_up: The original up matrix before training orig_down: The original down matrix before training initial_alpha: The alpha value used during URAE training (if any) rank: The rank for the standard LoRA (if None, uses the rank of trained_A) Returns: lora_up: Standard LoRA up matrix lora_down: Standard LoRA down matrix alpha: Appropriate alpha value for the LoRA """ with torch.no_grad(): # Calculate the weight delta delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) # Perform SVD on the delta U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False) # If rank is not specified, use the same rank as the trained matrices if rank is None: rank = trained_up.shape[1] else: # Ensure we don't exceed available singular values rank = min(rank, len(S)) # Create standard LoRA matrices using top singular values # This is now standard LoRA (using top values), not URAE (which used bottom values during training) lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank])) lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :] # Method 1: Preserve the Frobenius norm of the delta original_effect: float = torch.norm(delta_w, p="fro").item() unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item() # The scaling factor in lora is (alpha/r), so: # alpha/r × ||AB|| = ||delta_W|| # alpha = r × ||delta_W|| / ||AB|| if unscaled_lora_effect > 0: norm_based_alpha = rank * (original_effect / unscaled_lora_effect) else: norm_based_alpha = 1.0 # Fallback # Method 2: If initial_alpha is provided, adjust based on rank change if initial_alpha is not None: initial_rank = trained_up.shape[1] # Scale alpha proportionally if rank changed rank_adjusted_alpha = initial_alpha * (rank / initial_rank) else: rank_adjusted_alpha = None # Choose the appropriate alpha if rank_adjusted_alpha is not None: # Use the rank-adjusted alpha, but ensure it's not too different from norm-based # Cap the difference to avoid extreme values alpha = rank_adjusted_alpha # Optional: Cap alpha to be within a reasonable range of norm_based_alpha if norm_based_alpha > 0: max_factor = 5.0 # Allow up to 5x difference upper_bound = norm_based_alpha * max_factor lower_bound = norm_based_alpha / max_factor alpha = min(max(alpha, lower_bound), upper_bound) else: # Use norm-based alpha alpha = norm_based_alpha # Round to a clean value for better usability alpha = round(alpha, 2) # Ensure alpha is positive and within reasonable bounds alpha = max(0.1, min(alpha, 1024.0)) return lora_up, lora_down, alpha