Make sure initialization runs on the better device (CUDA if available)

This commit is contained in:
rockerBOO
2025-03-25 20:43:10 -04:00
parent 54d4de0e72
commit da47d17231
2 changed files with 79 additions and 46 deletions

View File

@@ -3,57 +3,77 @@ import math
import warnings
from typing import Optional
def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
    """Standard LoRA initialization.

    The up projection is zeroed so the adapter contributes nothing before
    training starts (the wrapped module's output is initially unchanged),
    while the down projection gets a Kaiming-uniform init (a=sqrt(5)),
    matching torch's default Linear weight initialization.
    """
    torch.nn.init.zeros_(lora_up.weight)
    torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
# URAE: Ultra-Resolution Adaptation with Ease
def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
device = device if device is not None else lora_down.weight.data.device
def initialize_urae(
    org_module: torch.nn.Module,
    lora_down: torch.nn.Module,
    lora_up: torch.nn.Module,
    scale: float,
    rank: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """URAE (Ultra-Resolution Adaptation with Ease) initialization.

    Decomposes ``org_module``'s 2-D weight with SVD, seeds the LoRA pair
    from the *smallest* ``rank`` singular components (the residual
    subspace), and subtracts ``scale * up @ down`` from the base weight so
    the combined module reproduces the original output at step 0.

    Args:
        org_module: Module whose 2-D ``weight`` is decomposed (assumes a
            Linear-like weight — TODO confirm Conv weights are pre-flattened).
        lora_down: Receives ``diag(sqrt(Sr)) @ Uhr``.
        lora_up: Receives ``Vr @ diag(sqrt(Sr))``.
        scale: LoRA scaling factor used for the base-weight subtraction.
        rank: Number of smallest singular components to keep.
        device: Device on which to run the (expensive) SVD; defaults to
            CUDA when available, else CPU.
        dtype: dtype for the resulting LoRA weights; defaults to each LoRA
            module's current dtype. (Fix: previously accepted but ignored.)
    """
    # Prefer the fastest available device for the SVD itself.
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    assert isinstance(device, torch.device), f"Invalid device type: {device}"

    # Remember where the base module lives so we can restore it afterwards
    # (fix: the result used to stay on the SVD device, silently migrating
    # a CPU module to CUDA).
    org_device = org_module.weight.data.device
    weight_dtype = org_module.weight.data.dtype

    # Work in float32 for numerical stability of the SVD.
    weight = org_module.weight.data.to(device, dtype=torch.float32)

    with torch.autocast(device.type):
        # SVD decomposition: weight = V @ diag(S) @ Uh
        V, S, Uh = torch.linalg.svd(weight, full_matrices=False)

        # URAE uses the LAST/SMALLEST singular values and vectors
        # (residual components), unlike PiSSA which takes the largest.
        Vr = V[:, -rank:]
        Sr = S[-rank:] / rank  # out-of-place: don't mutate S through the view
        Uhr = Uh[-rank:, :]

        # Split diag(Sr) symmetrically between the two factors.
        down = torch.diag(torch.sqrt(Sr)) @ Uhr
        up = Vr @ torch.diag(torch.sqrt(Sr))

        # Sanity-check the factor shapes against the LoRA modules.
        expected_down_shape = lora_down.weight.shape
        expected_up_shape = lora_up.weight.shape
        if down.shape != expected_down_shape:
            warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
        if up.shape != expected_up_shape:
            warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))

        # Assign to LoRA weights, restoring each module's own device and
        # casting to the requested dtype (fix: factors used to be left on
        # the SVD device — and, under CPU autocast, possibly bfloat16 —
        # mismatching the LoRA modules).
        lora_up.weight.data = up.to(lora_up.weight.data.device, dtype=dtype or lora_up.weight.dtype)
        lora_down.weight.data = down.to(lora_down.weight.data.device, dtype=dtype or lora_down.weight.dtype)

        # Subtract the adapter's initial contribution so base + scaled LoRA
        # equals the original weight at step 0.
        weight = weight - scale * (up @ down)

    # Restore the base weight's original device and dtype.
    org_module.weight.data = weight.to(device=org_device, dtype=weight_dtype)
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device: Optional[torch.device]=None, dtype: Optional[torch.dtype]=None):
weight_dtype = org_module.weight.data.dtype
device = device if device is not None else lora_down.weight.data.device
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
def initialize_pissa(
org_module: torch.nn.Module,
lora_down: torch.nn.Module,
lora_up: torch.nn.Module,
scale: float,
rank: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
weight_dtype = org_module.weight.data.dtype
org_module_requires_grad = org_module.weight.data.requires_grad
device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
assert isinstance(device, torch.device), f"Invalid device type: {device}"
weight = org_module.weight.data.clone().to(device, dtype=torch.float32)
@@ -61,10 +81,10 @@ def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lo
with torch.autocast(device.type):
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
Vr = V[:, : rank]
Sr = S[: rank]
Vr = V[:, :rank]
Sr = S[:rank]
Sr /= rank
Uhr = Uh[: rank]
Uhr = Uh[:rank]
down = torch.diag(torch.sqrt(Sr)) @ Uhr
up = Vr @ torch.diag(torch.sqrt(Sr))
@@ -76,7 +96,7 @@ def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lo
# Verify shapes match expected or reshape appropriately
if down.shape != expected_down_shape:
warnings.warn(UserWarning(f"Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
if up.shape != expected_up_shape:
warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))
@@ -85,3 +105,4 @@ def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lo
weight = weight.data - scale * (up @ down)
org_module.weight.data = weight.to(dtype=weight_dtype)
org_module.weight.data.requires_grad = org_module_requires_grad