Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-16 17:02:45 +00:00.
Latest commit: Add `pythonpath = .` to pytest to get the current directory; fix device of LoRA after PiSSA initialization to return to the proper device.
114 lines
4.3 KiB
Python
114 lines
4.3 KiB
Python
import torch
|
|
import math
|
|
import warnings
|
|
from typing import Optional
|
|
|
|
|
|
def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
    """Standard LoRA initialization.

    The down projection gets a Kaiming-uniform init while the up projection
    is zeroed, so the adapter contributes nothing at the start of training.
    """
    init = torch.nn.init
    init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
    init.zeros_(lora_up.weight)
|
|
|
|
|
|
# URAE: Ultra-Resolution Adaptation with Ease
def initialize_urae(
    org_module: torch.nn.Module,
    lora_down: torch.nn.Module,
    lora_up: torch.nn.Module,
    scale: float,
    rank: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Initialize a LoRA pair from the SMALLEST singular components of ``org_module``.

    URAE seeds the adapter with the residual (minor) part of the weight's SVD
    and subtracts ``scale * (up @ down)`` from the original weight, so the
    combined module output is unchanged at initialization.

    Args:
        org_module: module whose 2-D ``weight`` is decomposed (e.g. ``nn.Linear``).
        lora_down: module receiving the ``(rank, in_features)`` down matrix.
        lora_up: module receiving the ``(out_features, rank)`` up matrix.
        scale: multiplier applied to the low-rank product when it is removed
            from the original weight.
        rank: number of smallest singular components to extract.
        device: device used for the SVD; defaults to CUDA when available.
        dtype: unused; kept for signature compatibility with ``initialize_pissa``.
    """
    org_module_device = org_module.weight.device
    org_module_weight_dtype = org_module.weight.data.dtype
    org_module_requires_grad = org_module.weight.requires_grad

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    assert isinstance(device, torch.device), f"Invalid device type: {device}"

    # Run the SVD in float32 for numerical stability, regardless of stored dtype.
    weight = org_module.weight.data.to(device, dtype=torch.float32)

    # Use no_grad (not autocast) so the decomposition stays in full float32 and
    # no autograd graph is recorded; this matches initialize_pissa.
    with torch.no_grad():
        # SVD decomposition
        V, S, Uh = torch.linalg.svd(weight, full_matrices=False)

        # For URAE, use the LAST/SMALLEST singular values and vectors (residual components)
        Vr = V[:, -rank:]
        Sr = S[-rank:]
        Sr /= rank
        Uhr = Uh[-rank:, :]

        # Split sqrt(S) between the factors so up @ down reconstructs Vr diag(Sr) Uhr.
        down = torch.diag(torch.sqrt(Sr)) @ Uhr
        up = Vr @ torch.diag(torch.sqrt(Sr))

        # Get expected shapes
        expected_down_shape = lora_down.weight.shape
        expected_up_shape = lora_up.weight.shape

        # Verify shapes match expected
        if down.shape != expected_down_shape:
            warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))

        if up.shape != expected_up_shape:
            warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))

        # Assign to LoRA weights, restoring each module's own device/dtype.
        # (Previously the float32 SVD-device tensors were assigned directly,
        # leaving the LoRA weights on the wrong device/dtype — same fix as
        # initialize_pissa.)
        lora_up.weight.data = up.to(lora_up.weight.data.device, dtype=lora_up.weight.dtype)
        lora_down.weight.data = down.to(lora_down.weight.data.device, dtype=lora_down.weight.dtype)

        # Remove the extracted low-rank component from the original weight,
        # then restore the original device/dtype/requires_grad.
        weight = weight - scale * (up @ down)
        org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
        org_module.weight.requires_grad = org_module_requires_grad
|
|
|
|
|
|
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
def initialize_pissa(
    org_module: torch.nn.Module,
    lora_down: torch.nn.Module,
    lora_up: torch.nn.Module,
    scale: float,
    rank: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Seed a LoRA pair with the LARGEST singular components of ``org_module``.

    The top-``rank`` part of the weight's SVD (with singular values divided by
    ``rank``) is factored into ``lora_up @ lora_down``, and ``scale`` times that
    product is subtracted from the original weight so the combined output is
    unchanged at initialization. The original weight's device, dtype and
    ``requires_grad`` are restored afterwards; ``dtype`` is currently unused.
    """
    saved_device = org_module.weight.device
    saved_dtype = org_module.weight.data.dtype
    saved_requires_grad = org_module.weight.requires_grad

    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    assert isinstance(device, torch.device), f"Invalid device type: {device}"

    # Work on a float32 copy so the original weight tensor is untouched until the end.
    work = org_module.weight.data.clone().to(device, dtype=torch.float32)

    with torch.no_grad():
        # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
        V, S, Uh = torch.linalg.svd(work, full_matrices=False)

        # Top-rank singular values, scaled by 1/rank, split as sqrt between both factors.
        factor = torch.sqrt(S[:rank] / rank)
        down = factor.unsqueeze(1) * Uh[:rank]
        up = V[:, :rank] * factor.unsqueeze(0)

        # Warn (but proceed) if the factor shapes disagree with the LoRA modules.
        want_down = lora_down.weight.shape
        want_up = lora_up.weight.shape

        if down.shape != want_down:
            warnings.warn(UserWarning(f"Down matrix shape mismatch. Got {down.shape}, expected {want_down}"))

        if up.shape != want_up:
            warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {want_up}"))

        # Move each factor onto its target module's device/dtype before assigning.
        lora_up.weight.data = up.to(lora_up.weight.data.device, dtype=lora_up.weight.dtype)
        lora_down.weight.data = down.to(lora_down.weight.data.device, dtype=lora_down.weight.dtype)

        # Subtract the low-rank component, then restore the original placement.
        work = work - scale * (up @ down)
        org_module.weight.data = work.to(saved_device, dtype=saved_dtype)
        org_module.weight.requires_grad = saved_requires_grad
|
|
|