mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Update URAE initialization, conversion. Add tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import math
|
||||
import warnings
|
||||
from torch import Tensor
|
||||
from typing import Optional
|
||||
from library.incremental_pca import IncrementalPCA
|
||||
from dataclasses import dataclass
|
||||
@@ -108,73 +109,73 @@ def initialize_urae(
|
||||
lowrank_niter: int = 4,
|
||||
lowrank_seed: Optional[int] = None,
|
||||
):
|
||||
org_module_device = org_module.weight.device
|
||||
org_module_weight_dtype = org_module.weight.data.dtype
|
||||
org_module_requires_grad = org_module.weight.requires_grad
|
||||
# Store original device, dtype, and requires_grad status
|
||||
orig_device = org_module.weight.device
|
||||
orig_dtype = org_module.weight.data.dtype
|
||||
orig_requires_grad = org_module.weight.requires_grad
|
||||
|
||||
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
|
||||
# Determine device and dtype to work with
|
||||
device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
||||
assert isinstance(device, torch.device), f"Invalid device type: {device}"
|
||||
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
|
||||
|
||||
# Move original weight to chosen device and use float32 for numerical stability
|
||||
weight = org_module.weight.data.to(device, dtype=torch.float32)
|
||||
|
||||
# Perform SVD decomposition (either directly or with IPCA for memory efficiency)
|
||||
if use_ipca:
|
||||
# For URAE we need all components to get the "residual" ones
|
||||
ipca = IncrementalPCA(
|
||||
n_components=None, # Get all components
|
||||
n_components=None,
|
||||
batch_size=1024,
|
||||
lowrank=use_lowrank,
|
||||
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), # Use full rank for accurate residuals
|
||||
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
|
||||
lowrank_niter=lowrank_niter,
|
||||
lowrank_seed=lowrank_seed,
|
||||
)
|
||||
ipca.fit(weight)
|
||||
|
||||
# For URAE, use the LAST/SMALLEST singular values
|
||||
total_rank = min(weight.shape[0], weight.shape[1])
|
||||
V_full = ipca.components_.T # [out_features, total_rank]
|
||||
S_full = ipca.singular_values_ # [total_rank]
|
||||
# Extract singular values and vectors, focusing on the minor components (smallest singular values)
|
||||
S_full = ipca.singular_values_
|
||||
V_full = ipca.components_.T # Shape: [out_features, total_rank]
|
||||
|
||||
# Get the smallest singular values and vectors
|
||||
Vr = V_full[:, -rank:] # Last rank left singular vectors
|
||||
Sr = S_full[-rank:] # Last rank singular values
|
||||
Sr /= rank
|
||||
|
||||
# To get Uhr (last rank right singular vectors), transform basis vectors
|
||||
# Get identity matrix to transform for right singular vectors
|
||||
identity = torch.eye(weight.shape[1], device=weight.device)
|
||||
Uhr_full = ipca.transform(identity).T # [total_rank, in_features]
|
||||
Uhr = Uhr_full[-rank:] # Last rank right singular vectors
|
||||
Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features]
|
||||
|
||||
# Extract the last 'rank' components (the minor/smallest ones)
|
||||
Sr = S_full[-rank:]
|
||||
Vr = V_full[:, -rank:]
|
||||
Uhr = Uhr_full[-rank:]
|
||||
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
else:
|
||||
# Standard SVD approach
|
||||
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
|
||||
Vr = V[:, -rank:]
|
||||
# Direct SVD approach
|
||||
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
|
||||
|
||||
# Extract the minor components (smallest singular values)
|
||||
Sr = S[-rank:]
|
||||
Sr /= rank
|
||||
Uhr = Uh[-rank:, :]
|
||||
Vr = U[:, -rank:]
|
||||
Uhr = Vh[-rank:]
|
||||
|
||||
# Create down and up matrices
|
||||
down = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
up = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
|
||||
# Get expected shapes
|
||||
expected_down_shape = lora_down.weight.shape
|
||||
expected_up_shape = lora_up.weight.shape
|
||||
# Create the low-rank adapter matrices by splitting the minor components
|
||||
# Down matrix: scaled right singular vectors with singular values
|
||||
down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
|
||||
# Verify shapes match expected
|
||||
if down.shape != expected_down_shape:
|
||||
warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
|
||||
# Up matrix: scaled left singular vectors with singular values
|
||||
up_matrix = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
|
||||
if up.shape != expected_up_shape:
|
||||
warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
|
||||
# Assign to LoRA modules
|
||||
lora_down.weight.data = down_matrix.to(device=device, dtype=dtype)
|
||||
lora_up.weight.data = up_matrix.to(device=device, dtype=dtype)
|
||||
|
||||
# Assign to LoRA weights
|
||||
lora_up.weight.data = up
|
||||
lora_down.weight.data = down
|
||||
|
||||
# Optionally, subtract from original weight
|
||||
weight = weight - scale * (up @ down)
|
||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
||||
org_module.weight.requires_grad = org_module_requires_grad
|
||||
# Update the original weight by removing the minor components
|
||||
# This is equivalent to keeping only the major components
|
||||
modified_weight = weight - scale * (up_matrix @ down_matrix)
|
||||
org_module.weight.data = modified_weight.to(device=orig_device, dtype=orig_dtype)
|
||||
org_module.weight.requires_grad = orig_requires_grad
|
||||
|
||||
|
||||
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
|
||||
@@ -269,3 +270,107 @@ def initialize_pissa(
|
||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
||||
org_module.weight.requires_grad = org_module_requires_grad
|
||||
|
||||
|
||||
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
|
||||
# Calculate ΔW = A'B' - AB
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# We need to create new low-rank matrices that represent this delta
|
||||
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# Take the top 2*r singular values (as suggested in the paper)
|
||||
rank = rank * 2
|
||||
rank = min(rank, len(S)) # Make sure we don't exceed available singular values
|
||||
|
||||
# Create new LoRA matrices
|
||||
new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
|
||||
new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
|
||||
|
||||
# These matrices can now be used as standard LoRA weights
|
||||
return new_up, new_down
|
||||
|
||||
|
||||
def convert_urae_to_standard_lora(
|
||||
trained_up: Tensor,
|
||||
trained_down: Tensor,
|
||||
orig_up: Tensor,
|
||||
orig_down: Tensor,
|
||||
initial_alpha: float | None = None,
|
||||
rank: int | None = None,
|
||||
):
|
||||
"""
|
||||
Convert URAE trained weights to standard LoRA format
|
||||
|
||||
Args:
|
||||
trained_up: The trained URAE Up matrix
|
||||
trained_down: The trained URAE Down matrix
|
||||
orig_up: The original up matrix before training
|
||||
orig_down: The original down matrix before training
|
||||
initial_alpha: The alpha value used during URAE training (if any)
|
||||
rank: The rank for the standard LoRA (if None, uses the rank of trained_A)
|
||||
|
||||
Returns:
|
||||
lora_up: Standard LoRA up matrix
|
||||
lora_down: Standard LoRA down matrix
|
||||
alpha: Appropriate alpha value for the LoRA
|
||||
"""
|
||||
# Calculate the weight delta
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# Perform SVD on the delta
|
||||
U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# If rank is not specified, use the same rank as the trained matrices
|
||||
if rank is None:
|
||||
rank = trained_up.shape[1]
|
||||
else:
|
||||
# Ensure we don't exceed available singular values
|
||||
rank = min(rank, len(S))
|
||||
|
||||
# Create standard LoRA matrices using top singular values
|
||||
# This is now standard LoRA (using top values), not URAE (which used bottom values during training)
|
||||
lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
|
||||
lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
|
||||
|
||||
# Method 1: Preserve the Frobenius norm of the delta
|
||||
original_effect: float = torch.norm(delta_w, p="fro").item()
|
||||
unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item()
|
||||
|
||||
# The scaling factor in lora is (alpha/r), so:
|
||||
# alpha/r × ||AB|| = ||delta_W||
|
||||
# alpha = r × ||delta_W|| / ||AB||
|
||||
if unscaled_lora_effect > 0:
|
||||
norm_based_alpha = rank * (original_effect / unscaled_lora_effect)
|
||||
else:
|
||||
norm_based_alpha = 1.0 # Fallback
|
||||
|
||||
# Method 2: If initial_alpha is provided, adjust based on rank change
|
||||
if initial_alpha is not None:
|
||||
initial_rank = trained_up.shape[1]
|
||||
# Scale alpha proportionally if rank changed
|
||||
rank_adjusted_alpha = initial_alpha * (rank / initial_rank)
|
||||
else:
|
||||
rank_adjusted_alpha = None
|
||||
|
||||
# Choose the appropriate alpha
|
||||
if rank_adjusted_alpha is not None:
|
||||
# Use the rank-adjusted alpha, but ensure it's not too different from norm-based
|
||||
# Cap the difference to avoid extreme values
|
||||
alpha = rank_adjusted_alpha
|
||||
# Optional: Cap alpha to be within a reasonable range of norm_based_alpha
|
||||
if norm_based_alpha > 0:
|
||||
max_factor = 5.0 # Allow up to 5x difference
|
||||
upper_bound = norm_based_alpha * max_factor
|
||||
lower_bound = norm_based_alpha / max_factor
|
||||
alpha = min(max(alpha, lower_bound), upper_bound)
|
||||
else:
|
||||
# Use norm-based alpha
|
||||
alpha = norm_based_alpha
|
||||
|
||||
# Round to a clean value for better usability
|
||||
alpha = round(alpha, 2)
|
||||
|
||||
# Ensure alpha is positive and within reasonable bounds
|
||||
alpha = max(0.1, min(alpha, 1024.0))
|
||||
|
||||
return lora_up, lora_down, alpha
|
||||
|
||||
@@ -19,7 +19,14 @@ from tqdm import tqdm
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.device_utils import clean_memory_on_device
|
||||
from library.network_utils import initialize_lora, initialize_pissa, initialize_urae, initialize_parse_opts
|
||||
from library.network_utils import (
|
||||
initialize_lora,
|
||||
initialize_pissa,
|
||||
initialize_urae,
|
||||
initialize_parse_opts,
|
||||
convert_pissa_to_standard_lora,
|
||||
convert_urae_to_standard_lora,
|
||||
)
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -49,6 +56,7 @@ class LoRAModule(torch.nn.Module):
|
||||
split_dims: Optional[List[int]] = None,
|
||||
ggpo_beta: Optional[float] = None,
|
||||
ggpo_sigma: Optional[float] = None,
|
||||
initialize: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
if alpha == 0 or None, alpha is rank (no scaling).
|
||||
@@ -101,7 +109,6 @@ class LoRAModule(torch.nn.Module):
|
||||
self.initialize_norm_cache(org_module.weight)
|
||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||
|
||||
|
||||
def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]):
|
||||
"""
|
||||
Initialize the weights for the LoRA
|
||||
@@ -113,12 +120,16 @@ class LoRAModule(torch.nn.Module):
|
||||
if initialize is not None:
|
||||
params = initialize_parse_opts(initialize)
|
||||
if initialize[:4] == "urae":
|
||||
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
|
||||
initialize_urae(
|
||||
org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)
|
||||
)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||
elif initialize[:5] == "pissa":
|
||||
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
|
||||
initialize_pissa(
|
||||
org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)
|
||||
)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||
@@ -149,7 +160,6 @@ class LoRAModule(torch.nn.Module):
|
||||
self._org_lora_up = self._org_lora_up.to("cpu")
|
||||
self._org_lora_down = self._org_lora_down.to("cpu")
|
||||
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
@@ -607,7 +617,7 @@ def create_network(
|
||||
if verbose is not None:
|
||||
verbose = True if verbose == "True" else False
|
||||
|
||||
# Computation device, used in initialization
|
||||
# Computation device, used in initialization
|
||||
comp_device = kwargs.get("comp_device", None)
|
||||
if comp_device is not None:
|
||||
comp_device = torch.device(comp_device)
|
||||
@@ -922,7 +932,9 @@ class LoRANetwork(torch.nn.Module):
|
||||
elif "single" in lora_name and "linear1" in lora_name:
|
||||
split_dims = [3072] * 3 + [12288]
|
||||
|
||||
assert module_class is LoRAModule or module_class is LoRAInfModule, f"Module class is not valid {type(module_class)}"
|
||||
assert module_class is LoRAModule or module_class is LoRAInfModule, (
|
||||
f"Module class is not valid {type(module_class)}"
|
||||
)
|
||||
lora = module_class(
|
||||
lora_name,
|
||||
child_module,
|
||||
@@ -960,7 +972,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
logger.info(f"Initialize Text Encoder LoRA using {initialize}")
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index+1}.")
|
||||
logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index + 1}.")
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
|
||||
@@ -1313,66 +1325,33 @@ class LoRANetwork(torch.nn.Module):
|
||||
# Need to decompose the parameters into a LoRA format
|
||||
if self.initialize is not None and (self.initialize[:5] == "pissa" or self.initialize[:4] == "urae"):
|
||||
loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras
|
||||
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
|
||||
# Calculate ΔW = A'B' - AB
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# We need to create new low-rank matrices that represent this delta
|
||||
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# Take the top 2*r singular values (as suggested in the paper)
|
||||
rank = rank * 2
|
||||
rank = min(rank, len(S)) # Make sure we don't exceed available singular values
|
||||
|
||||
# Create new LoRA matrices
|
||||
new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
|
||||
new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
|
||||
|
||||
# These matrices can now be used as standard LoRA weights
|
||||
return new_up, new_down
|
||||
|
||||
def convert_urae_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
|
||||
# Calculate ΔW = A'B' - AB
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# We need to create new low-rank matrices that represent this delta
|
||||
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# For URAE, we want to focus on the smallest singular values
|
||||
# Take the bottom rank*2 singular values (opposite of PiSSA which takes the top ones)
|
||||
total_rank = len(S)
|
||||
rank_to_use = min(rank * 2, total_rank)
|
||||
|
||||
if rank_to_use < total_rank:
|
||||
# Use the smallest singular values and vectors
|
||||
selected_U = U[:, -rank_to_use:]
|
||||
selected_S = S[-rank_to_use:]
|
||||
selected_V = V[-rank_to_use:, :]
|
||||
else:
|
||||
# If we'd use all values, just use the standard approach but with a note
|
||||
print("Warning: Requested rank is too large for URAE specialty, using all singular values")
|
||||
selected_U = U
|
||||
selected_S = S
|
||||
selected_V = V
|
||||
|
||||
# Create new LoRA matrices
|
||||
new_up = selected_U @ torch.diag(torch.sqrt(selected_S))
|
||||
new_down = torch.diag(torch.sqrt(selected_S)) @ selected_V
|
||||
|
||||
# These matrices can now be used as standard LoRA weights
|
||||
return new_up, new_down
|
||||
|
||||
with torch.no_grad():
|
||||
progress = tqdm(total=len(loras), desc="Converting")
|
||||
for lora in loras:
|
||||
lora_up_key = f"{lora.lora_name}.lora_up.weight"
|
||||
lora_down_key = f"{lora.lora_name}.lora_down.weight"
|
||||
lora_alpha_key = f"{lora.lora_name}.alpha"
|
||||
lora_up = state_dict[lora_up_key]
|
||||
lora_down = state_dict[lora_down_key]
|
||||
if self.initialize[:4] == "urae":
|
||||
up, down = convert_urae_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
|
||||
up, down, alpha = convert_urae_to_standard_lora(
|
||||
lora_up,
|
||||
lora_down,
|
||||
lora._org_lora_up.to(lora_up.device),
|
||||
lora._org_lora_down.to(lora_up.device),
|
||||
initial_alpha=lora.alpha.item(),
|
||||
rank=lora.lora_dim,
|
||||
)
|
||||
state_dict[lora_alpha_key] = torch.tensor(alpha).to(dtype=lora.alpha.dtype)
|
||||
elif self.initialize[:5] == "pissa":
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
|
||||
up, down = convert_pissa_to_standard_lora(
|
||||
lora_up,
|
||||
lora_down,
|
||||
lora._org_lora_up.to(lora_up.device),
|
||||
lora._org_lora_down.to(lora_up.device),
|
||||
lora.lora_dim,
|
||||
)
|
||||
|
||||
# TODO: Capture option if we should offload
|
||||
# offload to CPU
|
||||
|
||||
241
tests/library/test_network_utils_urae.py
Normal file
241
tests/library/test_network_utils_urae.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.network_utils import convert_urae_to_standard_lora
|
||||
|
||||
|
||||
class TestConvertURAEToStandardLoRA:
|
||||
@pytest.fixture
|
||||
def sample_matrices(self):
|
||||
"""Create sample matrices for testing"""
|
||||
# Original up matrix (4x2)
|
||||
orig_up = torch.tensor([
|
||||
[1.0, 2.0],
|
||||
[3.0, 4.0],
|
||||
[5.0, 6.0],
|
||||
[7.0, 8.0]
|
||||
])
|
||||
|
||||
# Original down matrix (2x6)
|
||||
orig_down = torch.tensor([
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
[0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
|
||||
])
|
||||
|
||||
# Trained up matrix (4x2) - same shape as orig_up but with changed values
|
||||
trained_up = torch.tensor([
|
||||
[1.1, 2.1],
|
||||
[3.1, 4.1],
|
||||
[5.1, 6.1],
|
||||
[7.1, 8.1]
|
||||
])
|
||||
|
||||
# Trained down matrix (2x6) - same shape as orig_down but with changed values
|
||||
trained_down = torch.tensor([
|
||||
[0.15, 0.25, 0.35, 0.45, 0.55, 0.65],
|
||||
[0.75, 0.85, 0.95, 1.05, 1.15, 1.25]
|
||||
])
|
||||
|
||||
return {
|
||||
'orig_up': orig_up,
|
||||
'orig_down': orig_down,
|
||||
'trained_up': trained_up,
|
||||
'trained_down': trained_down
|
||||
}
|
||||
|
||||
def test_basic_conversion(self, sample_matrices):
|
||||
"""Test the basic functionality of convert_urae_to_standard_lora"""
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down']
|
||||
)
|
||||
|
||||
# Check shapes
|
||||
assert lora_up.shape[0] == sample_matrices['trained_up'].shape[0] # Same number of rows as trained_up
|
||||
assert lora_up.shape[1] == sample_matrices['trained_up'].shape[1] # Same rank as trained_up
|
||||
assert lora_down.shape[0] == sample_matrices['trained_up'].shape[1] # Same rank as trained_up
|
||||
assert lora_down.shape[1] == sample_matrices['trained_down'].shape[1] # Same number of columns as trained_down
|
||||
|
||||
# Check alpha is a reasonable value
|
||||
assert 0.1 <= alpha <= 1024.0
|
||||
|
||||
# Check that lora_up @ lora_down approximates the weight delta
|
||||
delta = (sample_matrices['trained_up'] @ sample_matrices['trained_down']) - (sample_matrices['orig_up'] @ sample_matrices['orig_down'])
|
||||
|
||||
# The approximation should be close in Frobenius norm after scaling
|
||||
lora_effect = lora_up @ lora_down
|
||||
delta_norm = torch.norm(delta, p="fro").item()
|
||||
lora_norm = torch.norm(lora_effect, p="fro").item()
|
||||
|
||||
# Either they are close, or the alpha scaling brings them close
|
||||
scaled_lora_effect = (alpha / sample_matrices['trained_up'].shape[1]) * lora_effect
|
||||
scaled_lora_norm = torch.norm(scaled_lora_effect, p="fro").item()
|
||||
|
||||
# At least one of these should be true
|
||||
assert abs(delta_norm - lora_norm) < 1e-4 or abs(delta_norm - scaled_lora_norm) < 1e-4
|
||||
|
||||
def test_specified_rank(self, sample_matrices):
|
||||
"""Test conversion with a specified rank"""
|
||||
new_rank = 1 # Lower than trained_up's rank of 2
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
rank=new_rank
|
||||
)
|
||||
|
||||
# Check that the new rank is used
|
||||
assert lora_up.shape[1] == new_rank
|
||||
assert lora_down.shape[0] == new_rank
|
||||
|
||||
# Should still produce a reasonable alpha
|
||||
assert 0.1 <= alpha <= 1024.0
|
||||
|
||||
def test_with_initial_alpha(self, sample_matrices):
|
||||
"""Test conversion with a specified initial alpha"""
|
||||
initial_alpha = 16.0
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
initial_alpha=initial_alpha
|
||||
)
|
||||
|
||||
# Alpha should be influenced by initial_alpha but may be adjusted
|
||||
# Since we're using same rank, should be reasonably close to initial_alpha
|
||||
assert 0.1 <= alpha <= 1024.0
|
||||
# Should at least preserve the order of magnitude in typical cases
|
||||
assert abs(alpha - initial_alpha) <= initial_alpha * 4.0
|
||||
|
||||
def test_large_initial_alpha(self, sample_matrices):
|
||||
"""Test conversion with a very large initial alpha that should be capped"""
|
||||
initial_alpha = 2000.0 # Larger than the 1024.0 cap
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
initial_alpha=initial_alpha
|
||||
)
|
||||
|
||||
# Alpha should be capped at 1024.0
|
||||
assert alpha <= 1024.0
|
||||
|
||||
def test_very_small_initial_alpha(self, sample_matrices):
|
||||
"""Test conversion with a very small initial alpha that should be floored"""
|
||||
initial_alpha = 0.01 # Smaller than the 0.1 floor
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
initial_alpha=initial_alpha
|
||||
)
|
||||
|
||||
# Alpha should be floored at 0.1
|
||||
assert alpha >= 0.1
|
||||
|
||||
def test_change_rank_with_initial_alpha(self, sample_matrices):
|
||||
"""Test conversion with both rank change and initial alpha"""
|
||||
initial_alpha = 16.0
|
||||
new_rank = 1 # Half of original rank 2
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
sample_matrices['trained_up'],
|
||||
sample_matrices['trained_down'],
|
||||
sample_matrices['orig_up'],
|
||||
sample_matrices['orig_down'],
|
||||
initial_alpha=initial_alpha,
|
||||
rank=new_rank
|
||||
)
|
||||
|
||||
# Check shapes
|
||||
assert lora_up.shape[1] == new_rank
|
||||
assert lora_down.shape[0] == new_rank
|
||||
|
||||
# Alpha should be adjusted for the rank change (approx halved in this case)
|
||||
expected_alpha = initial_alpha * (new_rank / sample_matrices['trained_up'].shape[1])
|
||||
# Allow some tolerance for adjustments from norm-based capping
|
||||
assert abs(alpha - expected_alpha) <= expected_alpha * 4.0 or alpha >= 0.1
|
||||
|
||||
def test_zero_delta(self):
|
||||
"""Test conversion when delta is zero"""
|
||||
# Create matrices where the delta will be zero
|
||||
dim_in, rank, dim_out = 4, 2, 6
|
||||
|
||||
# Create identical matrices for original and trained
|
||||
orig_up = torch.randn(dim_in, rank)
|
||||
orig_down = torch.randn(rank, dim_out)
|
||||
trained_up = orig_up.clone()
|
||||
trained_down = orig_down.clone()
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
trained_up,
|
||||
trained_down,
|
||||
orig_up,
|
||||
orig_down
|
||||
)
|
||||
|
||||
# Should still return matrices of correct shape
|
||||
assert lora_up.shape == (dim_in, rank)
|
||||
assert lora_down.shape == (rank, dim_out)
|
||||
|
||||
# Alpha should be at least the minimum
|
||||
assert alpha >= 0.1
|
||||
|
||||
def test_large_dimensions(self):
|
||||
"""Test with larger matrix dimensions"""
|
||||
dim_in, rank, dim_out = 100, 8, 200
|
||||
|
||||
orig_up = torch.randn(dim_in, rank)
|
||||
orig_down = torch.randn(rank, dim_out)
|
||||
trained_up = orig_up + 0.01 * torch.randn(dim_in, rank) # Small perturbation
|
||||
trained_down = orig_down + 0.01 * torch.randn(rank, dim_out) # Small perturbation
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
trained_up,
|
||||
trained_down,
|
||||
orig_up,
|
||||
orig_down
|
||||
)
|
||||
|
||||
# Check shapes
|
||||
assert lora_up.shape == (dim_in, rank)
|
||||
assert lora_down.shape == (rank, dim_out)
|
||||
|
||||
# Should produce a reasonable alpha
|
||||
assert 0.1 <= alpha <= 1024.0
|
||||
|
||||
def test_rank_exceeding_singular_values(self):
|
||||
"""Test when requested rank exceeds available singular values"""
|
||||
# Small matrices with limited rank
|
||||
dim_in, rank, dim_out = 3, 2, 3
|
||||
|
||||
orig_up = torch.randn(dim_in, rank)
|
||||
orig_down = torch.randn(rank, dim_out)
|
||||
trained_up = orig_up + 0.1 * torch.randn(dim_in, rank)
|
||||
trained_down = orig_down + 0.1 * torch.randn(rank, dim_out)
|
||||
|
||||
# Request rank larger than possible
|
||||
too_large_rank = 10
|
||||
|
||||
lora_up, lora_down, alpha = convert_urae_to_standard_lora(
|
||||
trained_up,
|
||||
trained_down,
|
||||
orig_up,
|
||||
orig_down,
|
||||
rank=too_large_rank
|
||||
)
|
||||
|
||||
# Rank should be limited to min(dim_in, dim_out, S.size)
|
||||
max_possible_rank = min(dim_in, dim_out)
|
||||
assert lora_up.shape[1] <= max_possible_rank
|
||||
assert lora_down.shape[0] <= max_possible_rank
|
||||
Reference in New Issue
Block a user