mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Update URAE initialization, conversion. Add tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
|
from torch import Tensor
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from library.incremental_pca import IncrementalPCA
|
from library.incremental_pca import IncrementalPCA
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -108,73 +109,73 @@ def initialize_urae(
|
|||||||
lowrank_niter: int = 4,
|
lowrank_niter: int = 4,
|
||||||
lowrank_seed: Optional[int] = None,
|
lowrank_seed: Optional[int] = None,
|
||||||
):
|
):
|
||||||
org_module_device = org_module.weight.device
|
# Store original device, dtype, and requires_grad status
|
||||||
org_module_weight_dtype = org_module.weight.data.dtype
|
orig_device = org_module.weight.device
|
||||||
org_module_requires_grad = org_module.weight.requires_grad
|
orig_dtype = org_module.weight.data.dtype
|
||||||
|
orig_requires_grad = org_module.weight.requires_grad
|
||||||
|
|
||||||
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
|
# Determine device and dtype to work with
|
||||||
device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
||||||
assert isinstance(device, torch.device), f"Invalid device type: {device}"
|
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
|
||||||
|
|
||||||
|
# Move original weight to chosen device and use float32 for numerical stability
|
||||||
weight = org_module.weight.data.to(device, dtype=torch.float32)
|
weight = org_module.weight.data.to(device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Perform SVD decomposition (either directly or with IPCA for memory efficiency)
|
||||||
if use_ipca:
|
if use_ipca:
|
||||||
# For URAE we need all components to get the "residual" ones
|
|
||||||
ipca = IncrementalPCA(
|
ipca = IncrementalPCA(
|
||||||
n_components=None, # Get all components
|
n_components=None,
|
||||||
batch_size=1024,
|
batch_size=1024,
|
||||||
lowrank=use_lowrank,
|
lowrank=use_lowrank,
|
||||||
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), # Use full rank for accurate residuals
|
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
|
||||||
lowrank_niter=lowrank_niter,
|
lowrank_niter=lowrank_niter,
|
||||||
lowrank_seed=lowrank_seed,
|
lowrank_seed=lowrank_seed,
|
||||||
)
|
)
|
||||||
ipca.fit(weight)
|
ipca.fit(weight)
|
||||||
|
|
||||||
# For URAE, use the LAST/SMALLEST singular values
|
# Extract singular values and vectors, focusing on the minor components (smallest singular values)
|
||||||
total_rank = min(weight.shape[0], weight.shape[1])
|
S_full = ipca.singular_values_
|
||||||
V_full = ipca.components_.T # [out_features, total_rank]
|
V_full = ipca.components_.T # Shape: [out_features, total_rank]
|
||||||
S_full = ipca.singular_values_ # [total_rank]
|
|
||||||
|
|
||||||
# Get the smallest singular values and vectors
|
# Get identity matrix to transform for right singular vectors
|
||||||
Vr = V_full[:, -rank:] # Last rank left singular vectors
|
|
||||||
Sr = S_full[-rank:] # Last rank singular values
|
|
||||||
Sr /= rank
|
|
||||||
|
|
||||||
# To get Uhr (last rank right singular vectors), transform basis vectors
|
|
||||||
identity = torch.eye(weight.shape[1], device=weight.device)
|
identity = torch.eye(weight.shape[1], device=weight.device)
|
||||||
Uhr_full = ipca.transform(identity).T # [total_rank, in_features]
|
Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features]
|
||||||
Uhr = Uhr_full[-rank:] # Last rank right singular vectors
|
|
||||||
|
# Extract the last 'rank' components (the minor/smallest ones)
|
||||||
|
Sr = S_full[-rank:]
|
||||||
|
Vr = V_full[:, -rank:]
|
||||||
|
Uhr = Uhr_full[-rank:]
|
||||||
|
|
||||||
|
# Scale singular values
|
||||||
|
Sr = Sr / rank
|
||||||
else:
|
else:
|
||||||
# Standard SVD approach
|
# Direct SVD approach
|
||||||
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
|
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
|
||||||
Vr = V[:, -rank:]
|
|
||||||
|
# Extract the minor components (smallest singular values)
|
||||||
Sr = S[-rank:]
|
Sr = S[-rank:]
|
||||||
Sr /= rank
|
Vr = U[:, -rank:]
|
||||||
Uhr = Uh[-rank:, :]
|
Uhr = Vh[-rank:]
|
||||||
|
|
||||||
# Create down and up matrices
|
# Scale singular values
|
||||||
down = torch.diag(torch.sqrt(Sr)) @ Uhr
|
Sr = Sr / rank
|
||||||
up = Vr @ torch.diag(torch.sqrt(Sr))
|
|
||||||
|
|
||||||
# Get expected shapes
|
# Create the low-rank adapter matrices by splitting the minor components
|
||||||
expected_down_shape = lora_down.weight.shape
|
# Down matrix: scaled right singular vectors with singular values
|
||||||
expected_up_shape = lora_up.weight.shape
|
down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||||
|
|
||||||
# Verify shapes match expected
|
# Up matrix: scaled left singular vectors with singular values
|
||||||
if down.shape != expected_down_shape:
|
up_matrix = Vr @ torch.diag(torch.sqrt(Sr))
|
||||||
warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
|
|
||||||
|
|
||||||
if up.shape != expected_up_shape:
|
# Assign to LoRA modules
|
||||||
warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
|
lora_down.weight.data = down_matrix.to(device=device, dtype=dtype)
|
||||||
|
lora_up.weight.data = up_matrix.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
# Assign to LoRA weights
|
# Update the original weight by removing the minor components
|
||||||
lora_up.weight.data = up
|
# This is equivalent to keeping only the major components
|
||||||
lora_down.weight.data = down
|
modified_weight = weight - scale * (up_matrix @ down_matrix)
|
||||||
|
org_module.weight.data = modified_weight.to(device=orig_device, dtype=orig_dtype)
|
||||||
# Optionally, subtract from original weight
|
org_module.weight.requires_grad = orig_requires_grad
|
||||||
weight = weight - scale * (up @ down)
|
|
||||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
|
||||||
org_module.weight.requires_grad = org_module_requires_grad
|
|
||||||
|
|
||||||
|
|
||||||
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
|
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
|
||||||
@@ -269,3 +270,107 @@ def initialize_pissa(
|
|||||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
||||||
org_module.weight.requires_grad = org_module_requires_grad
|
org_module.weight.requires_grad = org_module_requires_grad
|
||||||
|
|
||||||
|
|
||||||
|
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
    """
    Convert PiSSA trained weights to standard LoRA format

    The change learned during training is ΔW = A'B' - AB, where A'/B' are the trained
    up/down matrices and A/B are the matrices the adapter was initialized with. ΔW is
    re-factorized with a truncated SVD so it can be stored as ordinary LoRA weights.

    Args:
        trained_up: The trained up matrix (A')
        trained_down: The trained down matrix (B')
        orig_up: The up matrix before training (A)
        orig_down: The down matrix before training (B)
        rank: The rank used for training; up to 2 * rank singular components are kept

    Returns:
        new_up: Standard LoRA up matrix (float32)
        new_down: Standard LoRA down matrix (float32)
    """
    # Calculate ΔW = A'B' - AB
    delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)

    # Prefer CUDA for the SVD when it is available, but fall back to the tensor's own
    # device so the conversion also works on hosts without a GPU.
    svd_device = torch.device("cuda") if torch.cuda.is_available() else delta_w.device

    # We need to create new low-rank matrices that represent this delta.
    # torch.linalg.svd returns the third factor already transposed (Vh).
    U, S, Vh = torch.linalg.svd(delta_w.to(device=svd_device, dtype=torch.float32), full_matrices=False)

    # Take the top 2*r singular values (as suggested in the paper),
    # making sure we don't exceed the available singular values
    rank = min(rank * 2, len(S))

    # Split each singular value evenly between the two factors. Broadcasting scales the
    # columns of U and the rows of Vh without materializing a diagonal matrix.
    sqrt_s = torch.sqrt(S[:rank])
    new_up = U[:, :rank] * sqrt_s
    new_down = sqrt_s.unsqueeze(1) * Vh[:rank, :]

    # These matrices can now be used as standard LoRA weights
    return new_up, new_down
|
||||||
|
|
||||||
|
|
||||||
|
def convert_urae_to_standard_lora(
|
||||||
|
trained_up: Tensor,
|
||||||
|
trained_down: Tensor,
|
||||||
|
orig_up: Tensor,
|
||||||
|
orig_down: Tensor,
|
||||||
|
initial_alpha: float | None = None,
|
||||||
|
rank: int | None = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Convert URAE trained weights to standard LoRA format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trained_up: The trained URAE Up matrix
|
||||||
|
trained_down: The trained URAE Down matrix
|
||||||
|
orig_up: The original up matrix before training
|
||||||
|
orig_down: The original down matrix before training
|
||||||
|
initial_alpha: The alpha value used during URAE training (if any)
|
||||||
|
rank: The rank for the standard LoRA (if None, uses the rank of trained_A)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
lora_up: Standard LoRA up matrix
|
||||||
|
lora_down: Standard LoRA down matrix
|
||||||
|
alpha: Appropriate alpha value for the LoRA
|
||||||
|
"""
|
||||||
|
# Calculate the weight delta
|
||||||
|
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||||
|
|
||||||
|
# Perform SVD on the delta
|
||||||
|
U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False)
|
||||||
|
|
||||||
|
# If rank is not specified, use the same rank as the trained matrices
|
||||||
|
if rank is None:
|
||||||
|
rank = trained_up.shape[1]
|
||||||
|
else:
|
||||||
|
# Ensure we don't exceed available singular values
|
||||||
|
rank = min(rank, len(S))
|
||||||
|
|
||||||
|
# Create standard LoRA matrices using top singular values
|
||||||
|
# This is now standard LoRA (using top values), not URAE (which used bottom values during training)
|
||||||
|
lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
|
||||||
|
lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
|
||||||
|
|
||||||
|
# Method 1: Preserve the Frobenius norm of the delta
|
||||||
|
original_effect: float = torch.norm(delta_w, p="fro").item()
|
||||||
|
unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item()
|
||||||
|
|
||||||
|
# The scaling factor in lora is (alpha/r), so:
|
||||||
|
# alpha/r × ||AB|| = ||delta_W||
|
||||||
|
# alpha = r × ||delta_W|| / ||AB||
|
||||||
|
if unscaled_lora_effect > 0:
|
||||||
|
norm_based_alpha = rank * (original_effect / unscaled_lora_effect)
|
||||||
|
else:
|
||||||
|
norm_based_alpha = 1.0 # Fallback
|
||||||
|
|
||||||
|
# Method 2: If initial_alpha is provided, adjust based on rank change
|
||||||
|
if initial_alpha is not None:
|
||||||
|
initial_rank = trained_up.shape[1]
|
||||||
|
# Scale alpha proportionally if rank changed
|
||||||
|
rank_adjusted_alpha = initial_alpha * (rank / initial_rank)
|
||||||
|
else:
|
||||||
|
rank_adjusted_alpha = None
|
||||||
|
|
||||||
|
# Choose the appropriate alpha
|
||||||
|
if rank_adjusted_alpha is not None:
|
||||||
|
# Use the rank-adjusted alpha, but ensure it's not too different from norm-based
|
||||||
|
# Cap the difference to avoid extreme values
|
||||||
|
alpha = rank_adjusted_alpha
|
||||||
|
# Optional: Cap alpha to be within a reasonable range of norm_based_alpha
|
||||||
|
if norm_based_alpha > 0:
|
||||||
|
max_factor = 5.0 # Allow up to 5x difference
|
||||||
|
upper_bound = norm_based_alpha * max_factor
|
||||||
|
lower_bound = norm_based_alpha / max_factor
|
||||||
|
alpha = min(max(alpha, lower_bound), upper_bound)
|
||||||
|
else:
|
||||||
|
# Use norm-based alpha
|
||||||
|
alpha = norm_based_alpha
|
||||||
|
|
||||||
|
# Round to a clean value for better usability
|
||||||
|
alpha = round(alpha, 2)
|
||||||
|
|
||||||
|
# Ensure alpha is positive and within reasonable bounds
|
||||||
|
alpha = max(0.1, min(alpha, 1024.0))
|
||||||
|
|
||||||
|
return lora_up, lora_down, alpha
|
||||||
|
|||||||
@@ -19,7 +19,14 @@ from tqdm import tqdm
|
|||||||
import re
|
import re
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
from library.device_utils import clean_memory_on_device
|
from library.device_utils import clean_memory_on_device
|
||||||
from library.network_utils import initialize_lora, initialize_pissa, initialize_urae, initialize_parse_opts
|
from library.network_utils import (
|
||||||
|
initialize_lora,
|
||||||
|
initialize_pissa,
|
||||||
|
initialize_urae,
|
||||||
|
initialize_parse_opts,
|
||||||
|
convert_pissa_to_standard_lora,
|
||||||
|
convert_urae_to_standard_lora,
|
||||||
|
)
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
@@ -49,6 +56,7 @@ class LoRAModule(torch.nn.Module):
|
|||||||
split_dims: Optional[List[int]] = None,
|
split_dims: Optional[List[int]] = None,
|
||||||
ggpo_beta: Optional[float] = None,
|
ggpo_beta: Optional[float] = None,
|
||||||
ggpo_sigma: Optional[float] = None,
|
ggpo_sigma: Optional[float] = None,
|
||||||
|
initialize: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
if alpha == 0 or None, alpha is rank (no scaling).
|
if alpha == 0 or None, alpha is rank (no scaling).
|
||||||
@@ -101,7 +109,6 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.initialize_norm_cache(org_module.weight)
|
self.initialize_norm_cache(org_module.weight)
|
||||||
self.org_module_shape: tuple[int] = org_module.weight.shape
|
self.org_module_shape: tuple[int] = org_module.weight.shape
|
||||||
|
|
||||||
|
|
||||||
def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]):
|
def initialize_weights(self, org_module: torch.nn.Module, initialize: Optional[str], device: Optional[torch.device]):
|
||||||
"""
|
"""
|
||||||
Initialize the weights for the LoRA
|
Initialize the weights for the LoRA
|
||||||
@@ -113,12 +120,16 @@ class LoRAModule(torch.nn.Module):
|
|||||||
if initialize is not None:
|
if initialize is not None:
|
||||||
params = initialize_parse_opts(initialize)
|
params = initialize_parse_opts(initialize)
|
||||||
if initialize[:4] == "urae":
|
if initialize[:4] == "urae":
|
||||||
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
|
initialize_urae(
|
||||||
|
org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)
|
||||||
|
)
|
||||||
# Need to store the original weights so we can get a plain LoRA out
|
# Need to store the original weights so we can get a plain LoRA out
|
||||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||||
elif initialize[:5] == "pissa":
|
elif initialize[:5] == "pissa":
|
||||||
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params))
|
initialize_pissa(
|
||||||
|
org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)
|
||||||
|
)
|
||||||
# Need to store the original weights so we can get a plain LoRA out
|
# Need to store the original weights so we can get a plain LoRA out
|
||||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||||
@@ -149,7 +160,6 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self._org_lora_up = self._org_lora_up.to("cpu")
|
self._org_lora_up = self._org_lora_up.to("cpu")
|
||||||
self._org_lora_down = self._org_lora_down.to("cpu")
|
self._org_lora_down = self._org_lora_down.to("cpu")
|
||||||
|
|
||||||
|
|
||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
self.org_forward = self.org_module.forward
|
self.org_forward = self.org_module.forward
|
||||||
self.org_module.forward = self.forward
|
self.org_module.forward = self.forward
|
||||||
@@ -607,7 +617,7 @@ def create_network(
|
|||||||
if verbose is not None:
|
if verbose is not None:
|
||||||
verbose = True if verbose == "True" else False
|
verbose = True if verbose == "True" else False
|
||||||
|
|
||||||
# Computation device, used in initialization
|
# Computation device, used in initialization
|
||||||
comp_device = kwargs.get("comp_device", None)
|
comp_device = kwargs.get("comp_device", None)
|
||||||
if comp_device is not None:
|
if comp_device is not None:
|
||||||
comp_device = torch.device(comp_device)
|
comp_device = torch.device(comp_device)
|
||||||
@@ -922,7 +932,9 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
elif "single" in lora_name and "linear1" in lora_name:
|
elif "single" in lora_name and "linear1" in lora_name:
|
||||||
split_dims = [3072] * 3 + [12288]
|
split_dims = [3072] * 3 + [12288]
|
||||||
|
|
||||||
assert module_class is LoRAModule or module_class is LoRAInfModule, f"Module class is not valid {type(module_class)}"
|
assert module_class is LoRAModule or module_class is LoRAInfModule, (
|
||||||
|
f"Module class is not valid {type(module_class)}"
|
||||||
|
)
|
||||||
lora = module_class(
|
lora = module_class(
|
||||||
lora_name,
|
lora_name,
|
||||||
child_module,
|
child_module,
|
||||||
@@ -960,7 +972,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
logger.info(f"Initialize Text Encoder LoRA using {initialize}")
|
logger.info(f"Initialize Text Encoder LoRA using {initialize}")
|
||||||
|
|
||||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||||
logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index+1}.")
|
logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index + 1}.")
|
||||||
self.text_encoder_loras.extend(text_encoder_loras)
|
self.text_encoder_loras.extend(text_encoder_loras)
|
||||||
skipped_te += skipped
|
skipped_te += skipped
|
||||||
|
|
||||||
@@ -1313,66 +1325,33 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
# Need to decompose the parameters into a LoRA format
|
# Need to decompose the parameters into a LoRA format
|
||||||
if self.initialize is not None and (self.initialize[:5] == "pissa" or self.initialize[:4] == "urae"):
|
if self.initialize is not None and (self.initialize[:5] == "pissa" or self.initialize[:4] == "urae"):
|
||||||
loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras
|
loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras
|
||||||
def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
|
|
||||||
# Calculate ΔW = A'B' - AB
|
|
||||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
|
||||||
|
|
||||||
# We need to create new low-rank matrices that represent this delta
|
|
||||||
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
|
|
||||||
|
|
||||||
# Take the top 2*r singular values (as suggested in the paper)
|
|
||||||
rank = rank * 2
|
|
||||||
rank = min(rank, len(S)) # Make sure we don't exceed available singular values
|
|
||||||
|
|
||||||
# Create new LoRA matrices
|
|
||||||
new_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
|
|
||||||
new_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
|
|
||||||
|
|
||||||
# These matrices can now be used as standard LoRA weights
|
|
||||||
return new_up, new_down
|
|
||||||
|
|
||||||
def convert_urae_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int):
|
|
||||||
# Calculate ΔW = A'B' - AB
|
|
||||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
|
||||||
|
|
||||||
# We need to create new low-rank matrices that represent this delta
|
|
||||||
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
|
|
||||||
|
|
||||||
# For URAE, we want to focus on the smallest singular values
|
|
||||||
# Take the bottom rank*2 singular values (opposite of PiSSA which takes the top ones)
|
|
||||||
total_rank = len(S)
|
|
||||||
rank_to_use = min(rank * 2, total_rank)
|
|
||||||
|
|
||||||
if rank_to_use < total_rank:
|
|
||||||
# Use the smallest singular values and vectors
|
|
||||||
selected_U = U[:, -rank_to_use:]
|
|
||||||
selected_S = S[-rank_to_use:]
|
|
||||||
selected_V = V[-rank_to_use:, :]
|
|
||||||
else:
|
|
||||||
# If we'd use all values, just use the standard approach but with a note
|
|
||||||
print("Warning: Requested rank is too large for URAE specialty, using all singular values")
|
|
||||||
selected_U = U
|
|
||||||
selected_S = S
|
|
||||||
selected_V = V
|
|
||||||
|
|
||||||
# Create new LoRA matrices
|
|
||||||
new_up = selected_U @ torch.diag(torch.sqrt(selected_S))
|
|
||||||
new_down = torch.diag(torch.sqrt(selected_S)) @ selected_V
|
|
||||||
|
|
||||||
# These matrices can now be used as standard LoRA weights
|
|
||||||
return new_up, new_down
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
progress = tqdm(total=len(loras), desc="Converting")
|
progress = tqdm(total=len(loras), desc="Converting")
|
||||||
for lora in loras:
|
for lora in loras:
|
||||||
lora_up_key = f"{lora.lora_name}.lora_up.weight"
|
lora_up_key = f"{lora.lora_name}.lora_up.weight"
|
||||||
lora_down_key = f"{lora.lora_name}.lora_down.weight"
|
lora_down_key = f"{lora.lora_name}.lora_down.weight"
|
||||||
|
lora_alpha_key = f"{lora.lora_name}.alpha"
|
||||||
lora_up = state_dict[lora_up_key]
|
lora_up = state_dict[lora_up_key]
|
||||||
lora_down = state_dict[lora_down_key]
|
lora_down = state_dict[lora_down_key]
|
||||||
if self.initialize[:4] == "urae":
|
if self.initialize[:4] == "urae":
|
||||||
up, down = convert_urae_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
|
up, down, alpha = convert_urae_to_standard_lora(
|
||||||
|
lora_up,
|
||||||
|
lora_down,
|
||||||
|
lora._org_lora_up.to(lora_up.device),
|
||||||
|
lora._org_lora_down.to(lora_up.device),
|
||||||
|
initial_alpha=lora.alpha.item(),
|
||||||
|
rank=lora.lora_dim,
|
||||||
|
)
|
||||||
|
state_dict[lora_alpha_key] = torch.tensor(alpha).to(dtype=lora.alpha.dtype)
|
||||||
elif self.initialize[:5] == "pissa":
|
elif self.initialize[:5] == "pissa":
|
||||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim)
|
up, down = convert_pissa_to_standard_lora(
|
||||||
|
lora_up,
|
||||||
|
lora_down,
|
||||||
|
lora._org_lora_up.to(lora_up.device),
|
||||||
|
lora._org_lora_down.to(lora_up.device),
|
||||||
|
lora.lora_dim,
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Capture option if we should offload
|
# TODO: Capture option if we should offload
|
||||||
# offload to CPU
|
# offload to CPU
|
||||||
|
|||||||
241
tests/library/test_network_utils_urae.py
Normal file
241
tests/library/test_network_utils_urae.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from library.network_utils import convert_urae_to_standard_lora
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertURAEToStandardLoRA:
    """Tests for convert_urae_to_standard_lora: output shapes, rank handling and alpha selection."""

    @pytest.fixture
    def sample_matrices(self):
        """Create sample matrices for testing"""
        # Original up matrix (4x2)
        orig_up = torch.tensor([
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
            [7.0, 8.0],
        ])

        # Original down matrix (2x6)
        orig_down = torch.tensor([
            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            [0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
        ])

        # Trained up matrix (4x2) - same shape as orig_up but with changed values
        trained_up = torch.tensor([
            [1.1, 2.1],
            [3.1, 4.1],
            [5.1, 6.1],
            [7.1, 8.1],
        ])

        # Trained down matrix (2x6) - same shape as orig_down but with changed values
        trained_down = torch.tensor([
            [0.15, 0.25, 0.35, 0.45, 0.55, 0.65],
            [0.75, 0.85, 0.95, 1.05, 1.15, 1.25],
        ])

        return {
            'orig_up': orig_up,
            'orig_down': orig_down,
            'trained_up': trained_up,
            'trained_down': trained_down,
        }

    def test_basic_conversion(self, sample_matrices):
        """Test the basic functionality of convert_urae_to_standard_lora"""
        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
        )

        # Check shapes
        assert lora_up.shape[0] == sample_matrices['trained_up'].shape[0]  # Same number of rows as trained_up
        assert lora_up.shape[1] == sample_matrices['trained_up'].shape[1]  # Same rank as trained_up
        assert lora_down.shape[0] == sample_matrices['trained_up'].shape[1]  # Same rank as trained_up
        assert lora_down.shape[1] == sample_matrices['trained_down'].shape[1]  # Same number of columns as trained_down

        # Check alpha is a reasonable value
        assert 0.1 <= alpha <= 1024.0

        # Check that lora_up @ lora_down approximates the weight delta
        delta = (sample_matrices['trained_up'] @ sample_matrices['trained_down']) - (
            sample_matrices['orig_up'] @ sample_matrices['orig_down']
        )

        # The approximation should be close in Frobenius norm after scaling
        lora_effect = lora_up @ lora_down
        delta_norm = torch.norm(delta, p="fro").item()
        lora_norm = torch.norm(lora_effect, p="fro").item()

        # Either they are close, or the alpha scaling brings them close
        scaled_lora_effect = (alpha / sample_matrices['trained_up'].shape[1]) * lora_effect
        scaled_lora_norm = torch.norm(scaled_lora_effect, p="fro").item()

        # At least one of these should be true
        assert abs(delta_norm - lora_norm) < 1e-4 or abs(delta_norm - scaled_lora_norm) < 1e-4

    def test_specified_rank(self, sample_matrices):
        """Test conversion with a specified rank"""
        new_rank = 1  # Lower than trained_up's rank of 2

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            rank=new_rank,
        )

        # Check that the new rank is used
        assert lora_up.shape[1] == new_rank
        assert lora_down.shape[0] == new_rank

        # Should still produce a reasonable alpha
        assert 0.1 <= alpha <= 1024.0

    def test_with_initial_alpha(self, sample_matrices):
        """Test conversion with a specified initial alpha"""
        initial_alpha = 16.0

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha,
        )

        # Alpha should be influenced by initial_alpha but may be adjusted
        # Since we're using same rank, should be reasonably close to initial_alpha
        assert 0.1 <= alpha <= 1024.0
        # Should at least preserve the order of magnitude in typical cases
        assert abs(alpha - initial_alpha) <= initial_alpha * 4.0

    def test_large_initial_alpha(self, sample_matrices):
        """Test conversion with a very large initial alpha that should be capped"""
        initial_alpha = 2000.0  # Larger than the 1024.0 cap

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha,
        )

        # Alpha should be capped at 1024.0
        assert alpha <= 1024.0

    def test_very_small_initial_alpha(self, sample_matrices):
        """Test conversion with a very small initial alpha that should be floored"""
        initial_alpha = 0.01  # Smaller than the 0.1 floor

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha,
        )

        # Alpha should be floored at 0.1
        assert alpha >= 0.1

    def test_change_rank_with_initial_alpha(self, sample_matrices):
        """Test conversion with both rank change and initial alpha"""
        initial_alpha = 16.0
        new_rank = 1  # Half of original rank 2

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            sample_matrices['trained_up'],
            sample_matrices['trained_down'],
            sample_matrices['orig_up'],
            sample_matrices['orig_down'],
            initial_alpha=initial_alpha,
            rank=new_rank,
        )

        # Check shapes
        assert lora_up.shape[1] == new_rank
        assert lora_down.shape[0] == new_rank

        # Alpha should be adjusted for the rank change (approx halved in this case)
        expected_alpha = initial_alpha * (new_rank / sample_matrices['trained_up'].shape[1])
        # Allow some tolerance for adjustments from norm-based capping.
        # The bounds are asserted separately: alpha is always >= 0.1, so folding that
        # check into the tolerance check with `or` would make the assertion vacuous.
        assert 0.1 <= alpha <= 1024.0
        assert abs(alpha - expected_alpha) <= expected_alpha * 4.0

    def test_zero_delta(self):
        """Test conversion when delta is zero"""
        torch.manual_seed(0)  # Keep the random inputs reproducible between runs

        # Create matrices where the delta will be zero
        dim_in, rank, dim_out = 4, 2, 6

        # Create identical matrices for original and trained
        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up.clone()
        trained_down = orig_down.clone()

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down,
        )

        # Should still return matrices of correct shape
        assert lora_up.shape == (dim_in, rank)
        assert lora_down.shape == (rank, dim_out)

        # Alpha should be at least the minimum
        assert alpha >= 0.1

    def test_large_dimensions(self):
        """Test with larger matrix dimensions"""
        torch.manual_seed(0)  # Keep the random inputs reproducible between runs

        dim_in, rank, dim_out = 100, 8, 200

        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up + 0.01 * torch.randn(dim_in, rank)  # Small perturbation
        trained_down = orig_down + 0.01 * torch.randn(rank, dim_out)  # Small perturbation

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down,
        )

        # Check shapes
        assert lora_up.shape == (dim_in, rank)
        assert lora_down.shape == (rank, dim_out)

        # Should produce a reasonable alpha
        assert 0.1 <= alpha <= 1024.0

    def test_rank_exceeding_singular_values(self):
        """Test when requested rank exceeds available singular values"""
        torch.manual_seed(0)  # Keep the random inputs reproducible between runs

        # Small matrices with limited rank
        dim_in, rank, dim_out = 3, 2, 3

        orig_up = torch.randn(dim_in, rank)
        orig_down = torch.randn(rank, dim_out)
        trained_up = orig_up + 0.1 * torch.randn(dim_in, rank)
        trained_down = orig_down + 0.1 * torch.randn(rank, dim_out)

        # Request rank larger than possible
        too_large_rank = 10

        lora_up, lora_down, alpha = convert_urae_to_standard_lora(
            trained_up,
            trained_down,
            orig_up,
            orig_down,
            rank=too_large_rank,
        )

        # Rank should be limited to min(dim_in, dim_out, S.size)
        max_possible_rank = min(dim_in, dim_out)
        assert lora_up.shape[1] <= max_possible_rank
        assert lora_down.shape[0] <= max_possible_rank
|
||||||
Reference in New Issue
Block a user