mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Remove lowrank for URAE
This commit is contained in:
@@ -3,7 +3,6 @@ import math
|
||||
import warnings
|
||||
from torch import Tensor
|
||||
from typing import Optional
|
||||
from library.incremental_pca import IncrementalPCA
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -24,7 +23,6 @@ def initialize_parse_opts(key: str) -> InitializeParams:
|
||||
- "pissa" -> Default PiSSA with lowrank=True, niter=4
|
||||
- "pissa_niter_4" -> PiSSA with niter=4
|
||||
- "pissa_lowrank_false" -> PiSSA without lowrank
|
||||
- "urae_..." -> Same options but for URAE
|
||||
|
||||
Args:
|
||||
key: String key to parse
|
||||
@@ -80,11 +78,7 @@ def initialize_urae(
|
||||
rank: int,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
use_ipca: bool = False,
|
||||
use_lowrank: bool = True,
|
||||
lowrank_q: Optional[int] = None,
|
||||
lowrank_niter: int = 4,
|
||||
lowrank_seed: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# Store original device, dtype, and requires_grad status
|
||||
orig_device = org_module.weight.device
|
||||
@@ -98,45 +92,17 @@ def initialize_urae(
|
||||
# Move original weight to chosen device and use float32 for numerical stability
|
||||
weight = org_module.weight.data.to(device, dtype=torch.float32)
|
||||
|
||||
with torch.autocast(device.type), torch.no_grad():
|
||||
# Perform SVD decomposition (either directly or with IPCA for memory efficiency)
|
||||
if use_ipca:
|
||||
ipca = IncrementalPCA(
|
||||
n_components=None,
|
||||
batch_size=1024,
|
||||
lowrank=use_lowrank,
|
||||
lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),
|
||||
lowrank_niter=lowrank_niter,
|
||||
lowrank_seed=lowrank_seed,
|
||||
)
|
||||
ipca.fit(weight)
|
||||
with torch.no_grad():
|
||||
# Direct SVD approach
|
||||
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
|
||||
|
||||
# Extract singular values and vectors, focusing on the minor components (smallest singular values)
|
||||
S_full = ipca.singular_values_
|
||||
V_full = ipca.components_.T # Shape: [out_features, total_rank]
|
||||
# Extract the minor components (smallest singular values)
|
||||
Sr = S[-rank:]
|
||||
Vr = U[:, -rank:]
|
||||
Uhr = Vh[-rank:]
|
||||
|
||||
# Get identity matrix to transform for right singular vectors
|
||||
identity = torch.eye(weight.shape[1], device=weight.device)
|
||||
Uhr_full = ipca.transform(identity).T # Shape: [total_rank, in_features]
|
||||
|
||||
# Extract the last 'rank' components (the minor/smallest ones)
|
||||
Sr = S_full[-rank:]
|
||||
Vr = V_full[:, -rank:]
|
||||
Uhr = Uhr_full[-rank:]
|
||||
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
else:
|
||||
# Direct SVD approach
|
||||
U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
|
||||
|
||||
# Extract the minor components (smallest singular values)
|
||||
Sr = S[-rank:]
|
||||
Vr = U[:, -rank:]
|
||||
Uhr = Vh[-rank:]
|
||||
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
# Scale singular values
|
||||
Sr = Sr / rank
|
||||
|
||||
# Create the low-rank adapter matrices by splitting the minor components
|
||||
# Down matrix: scaled right singular vectors with singular values
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
import torch
|
||||
import pytest
|
||||
from library.network_utils import initialize_pissa
|
||||
|
||||
Reference in New Issue
Block a user