mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 01:12:41 +00:00
Update URAE initialization, conversion. Add tests
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import math
|
||||
import warnings
|
||||
from torch import Tensor
|
||||
from typing import Optional
|
||||
from library.incremental_pca import IncrementalPCA
|
||||
from dataclasses import dataclass
|
||||
@@ -108,73 +109,73 @@ def initialize_urae(
|
||||
lowrank_niter: int = 4,
|
||||
lowrank_seed: Optional[int] = None,
|
||||
):
|
||||
org_module_device = org_module.weight.device
|
||||
org_module_weight_dtype = org_module.weight.data.dtype
|
||||
org_module_requires_grad = org_module.weight.requires_grad
|
||||
# Store original device, dtype, and requires_grad status
|
||||
orig_device = org_module.weight.device
|
||||
orig_dtype = org_module.weight.data.dtype
|
||||
orig_requires_grad = org_module.weight.requires_grad
|
||||
|
||||
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
|
||||
# Determine device and dtype to work with
|
||||
device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
|
||||
assert isinstance(device, torch.device), f"Invalid device type: {device}"
|
||||
dtype = dtype if dtype is not None else lora_down.weight.data.dtype
|
||||
|
||||
# Move original weight to chosen device and use float32 for numerical stability
|
||||
weight = org_module.weight.data.to(device, dtype=torch.float32)
|
||||
|
||||
# Perform SVD decomposition (either directly or with IPCA for memory efficiency)
|
||||
if use_ipca:
|
||||
# For URAE we need all components to get the "residual" ones
|
||||
ipca = IncrementalPCA(
|
||||
n_components=None, # Get all components
|
||||
n_components=None,
|
||||
batch_size=1024,
|
||||
lowrank=use_lowrank,
|
||||
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape), # Use full rank for accurate residuals
|
||||
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
|
||||
lowrank_niter=lowrank_niter,
|
||||
lowrank_seed=lowrank_seed,
|
||||
)
|
||||
ipca.fit(weight)
|
||||
|
||||
# For URAE, use the LAST/SMALLEST singular values
|
||||
total_rank = min(weight.shape[0], weight.shape[1])
|
||||
V_full = ipca.components_.T # [out_features, total_rank]
|
||||
S_full = ipca.singular_values_ # [total_rank]
|
||||
# Extract singular values and vectors, focusing on the minor components (smallest singular values)
|
||||
S_full = ipca.singular_values_
|
||||
V_full = ipca.components_.T # Shape: [out_features, total_rank]
|
||||
|
||||
# Get the smallest singular values and vectors
|
||||
Vr = V_full[:, -rank:] # Last rank left singular vectors
|
||||
Sr = S_full[-rank:] # Last rank singular values
|
||||
Sr /= rank
|
||||
|
||||
# To get Uhr (last rank right singular vectors), transform basis vectors
|
||||
# Get identity matrix to transform for right singular vectors
|
||||
identity = torch.eye(weight.shape[1], device=weight.device)
|
||||
Uhr_full = ipca.transform(identity).T # [total_rank, in_features]
|
||||
Uhr = Uhr_full[-rank:] # Last rank right singular vectors
|
||||
Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features]
|
||||
|
||||
# Extract the last 'rank' components (the minor/smallest ones)
|
||||
Sr = S_full[-rank:]
|
||||
Vr = V_full[:, -rank:]
|
||||
Uhr = Uhr_full[-rank:]
|
||||
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
else:
|
||||
# Standard SVD approach
|
||||
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
|
||||
Vr = V[:, -rank:]
|
||||
# Direct SVD approach
|
||||
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
|
||||
|
||||
# Extract the minor components (smallest singular values)
|
||||
Sr = S[-rank:]
|
||||
Sr /= rank
|
||||
Uhr = Uh[-rank:, :]
|
||||
Vr = U[:, -rank:]
|
||||
Uhr = Vh[-rank:]
|
||||
|
||||
# Create down and up matrices
|
||||
down = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
up = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
|
||||
# Get expected shapes
|
||||
expected_down_shape = lora_down.weight.shape
|
||||
expected_up_shape = lora_up.weight.shape
|
||||
# Create the low-rank adapter matrices by splitting the minor components
|
||||
# Down matrix: scaled right singular vectors with singular values
|
||||
down_matrix = torch.diag(torch.sqrt(Sr)) @ Uhr
|
||||
|
||||
# Verify shapes match expected
|
||||
if down.shape != expected_down_shape:
|
||||
warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
|
||||
# Up matrix: scaled left singular vectors with singular values
|
||||
up_matrix = Vr @ torch.diag(torch.sqrt(Sr))
|
||||
|
||||
if up.shape != expected_up_shape:
|
||||
warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
|
||||
# Assign to LoRA modules
|
||||
lora_down.weight.data = down_matrix.to(device=device, dtype=dtype)
|
||||
lora_up.weight.data = up_matrix.to(device=device, dtype=dtype)
|
||||
|
||||
# Assign to LoRA weights
|
||||
lora_up.weight.data = up
|
||||
lora_down.weight.data = down
|
||||
|
||||
# Optionally, subtract from original weight
|
||||
weight = weight - scale * (up @ down)
|
||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
||||
org_module.weight.requires_grad = org_module_requires_grad
|
||||
# Update the original weight by removing the minor components
|
||||
# This is equivalent to keeping only the major components
|
||||
modified_weight = weight - scale * (up_matrix @ down_matrix)
|
||||
org_module.weight.data = modified_weight.to(device=orig_device, dtype=orig_dtype)
|
||||
org_module.weight.requires_grad = orig_requires_grad
|
||||
|
||||
|
||||
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
|
||||
@@ -269,3 +270,107 @@ def initialize_pissa(
|
||||
org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
|
||||
org_module.weight.requires_grad = org_module_requires_grad
|
||||
|
||||
|
||||
def convert_pissa_to_standard_lora(
    trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int
) -> tuple[Tensor, Tensor]:
    """
    Convert PiSSA-trained adapter weights into standard LoRA matrices.

    The weight change produced by training is ``trained_up @ trained_down - orig_up @ orig_down``.
    That delta is re-factorised with an SVD and truncated to ``2 * rank`` components
    (clamped to the number of singular values available).

    Args:
        trained_up: Up matrix after training.
        trained_down: Down matrix after training.
        orig_up: Up matrix as initialised, before training.
        orig_down: Down matrix as initialised, before training.
        rank: Rank of the trained adapter; the result uses up to twice this rank.

    Returns:
        ``(new_up, new_down)`` in float32, located on the device the SVD ran on
        (CUDA when available, otherwise the device of the input tensors).
        ``new_up @ new_down`` approximates the weight delta.
    """
    # ΔW = A'B' - AB
    delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)

    # Run the SVD on CUDA when it exists; otherwise stay on the tensor's own device
    # instead of failing on CPU-only hosts.
    compute_device = torch.device("cuda") if torch.cuda.is_available() else delta_w.device
    U, S, Vh = torch.linalg.svd(delta_w.to(device=compute_device, dtype=torch.float32), full_matrices=False)

    # Keep the top 2*r singular values (the original comment attributes this to the paper),
    # without exceeding the number of singular values available.
    target_rank = min(rank * 2, len(S))

    # Split each singular value evenly between both factors. Broadcasting scales the
    # columns of U and the rows of Vh, which is the same as multiplying by diag(sqrt(S)).
    sqrt_s = torch.sqrt(S[:target_rank])
    new_up = U[:, :target_rank] * sqrt_s
    new_down = sqrt_s.unsqueeze(1) * Vh[:target_rank, :]

    # These matrices can now be used as standard LoRA weights
    return new_up, new_down
|
||||
|
||||
|
||||
def convert_urae_to_standard_lora(
|
||||
trained_up: Tensor,
|
||||
trained_down: Tensor,
|
||||
orig_up: Tensor,
|
||||
orig_down: Tensor,
|
||||
initial_alpha: float | None = None,
|
||||
rank: int | None = None,
|
||||
):
|
||||
"""
|
||||
Convert URAE trained weights to standard LoRA format
|
||||
|
||||
Args:
|
||||
trained_up: The trained URAE Up matrix
|
||||
trained_down: The trained URAE Down matrix
|
||||
orig_up: The original up matrix before training
|
||||
orig_down: The original down matrix before training
|
||||
initial_alpha: The alpha value used during URAE training (if any)
|
||||
rank: The rank for the standard LoRA (if None, uses the rank of trained_A)
|
||||
|
||||
Returns:
|
||||
lora_up: Standard LoRA up matrix
|
||||
lora_down: Standard LoRA down matrix
|
||||
alpha: Appropriate alpha value for the LoRA
|
||||
"""
|
||||
# Calculate the weight delta
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# Perform SVD on the delta
|
||||
U, S, V = torch.linalg.svd(delta_w.to(dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# If rank is not specified, use the same rank as the trained matrices
|
||||
if rank is None:
|
||||
rank = trained_up.shape[1]
|
||||
else:
|
||||
# Ensure we don't exceed available singular values
|
||||
rank = min(rank, len(S))
|
||||
|
||||
# Create standard LoRA matrices using top singular values
|
||||
# This is now standard LoRA (using top values), not URAE (which used bottom values during training)
|
||||
lora_up = U[:, :rank] @ torch.diag(torch.sqrt(S[:rank]))
|
||||
lora_down = torch.diag(torch.sqrt(S[:rank])) @ V[:rank, :]
|
||||
|
||||
# Method 1: Preserve the Frobenius norm of the delta
|
||||
original_effect: float = torch.norm(delta_w, p="fro").item()
|
||||
unscaled_lora_effect: float = torch.norm(lora_up @ lora_down, p="fro").item()
|
||||
|
||||
# The scaling factor in lora is (alpha/r), so:
|
||||
# alpha/r × ||AB|| = ||delta_W||
|
||||
# alpha = r × ||delta_W|| / ||AB||
|
||||
if unscaled_lora_effect > 0:
|
||||
norm_based_alpha = rank * (original_effect / unscaled_lora_effect)
|
||||
else:
|
||||
norm_based_alpha = 1.0 # Fallback
|
||||
|
||||
# Method 2: If initial_alpha is provided, adjust based on rank change
|
||||
if initial_alpha is not None:
|
||||
initial_rank = trained_up.shape[1]
|
||||
# Scale alpha proportionally if rank changed
|
||||
rank_adjusted_alpha = initial_alpha * (rank / initial_rank)
|
||||
else:
|
||||
rank_adjusted_alpha = None
|
||||
|
||||
# Choose the appropriate alpha
|
||||
if rank_adjusted_alpha is not None:
|
||||
# Use the rank-adjusted alpha, but ensure it's not too different from norm-based
|
||||
# Cap the difference to avoid extreme values
|
||||
alpha = rank_adjusted_alpha
|
||||
# Optional: Cap alpha to be within a reasonable range of norm_based_alpha
|
||||
if norm_based_alpha > 0:
|
||||
max_factor = 5.0 # Allow up to 5x difference
|
||||
upper_bound = norm_based_alpha * max_factor
|
||||
lower_bound = norm_based_alpha / max_factor
|
||||
alpha = min(max(alpha, lower_bound), upper_bound)
|
||||
else:
|
||||
# Use norm-based alpha
|
||||
alpha = norm_based_alpha
|
||||
|
||||
# Round to a clean value for better usability
|
||||
alpha = round(alpha, 2)
|
||||
|
||||
# Ensure alpha is positive and within reasonable bounds
|
||||
alpha = max(0.1, min(alpha, 1024.0))
|
||||
|
||||
return lora_up, lora_down, alpha
|
||||
|
||||
Reference in New Issue
Block a user