@dataclass
class InitializeParams:
    """Options controlling PiSSA / URAE weight initialization."""

    use_ipca: bool = False  # decompose via IncrementalPCA instead of a direct SVD
    use_lowrank: bool = True  # use the randomized low-rank SVD (faster)
    lowrank_q: Optional[int] = None  # oversampled subspace size for svd_lowrank
    lowrank_niter: int = 4  # subspace iterations for svd_lowrank
    lowrank_seed: Optional[int] = None  # RNG seed for reproducible svd_lowrank


def initialize_parse_opts(key: str) -> InitializeParams:
    """
    Parse an initialization spec string into :class:`InitializeParams`.

    The spec is underscore-delimited and case-insensitive. The first token is
    the method name (``pissa`` or ``urae``); the remaining tokens are options:

    - ``ipca`` / ``lowrank`` — boolean flags; optionally followed by
      ``true``/``false`` (a bare flag means ``true``).
    - ``niter N`` / ``q N`` / ``seed N`` — integer-valued options; the token
      pair is consumed only when ``N`` is a digit string.

    Unrecognized tokens are silently skipped.

    Args:
        key: Spec string, e.g. ``"pissa_lowrank_false_niter_8"``.

    Returns:
        InitializeParams populated from the spec.

    Raises:
        ValueError: If the leading method token is not ``pissa`` or ``urae``.
    """
    tokens = key.lower().split("_")
    method, opts = tokens[0], tokens[1:]
    if method not in ("pissa", "urae"):
        raise ValueError(f"Unknown initialization method: {method}")

    # Maps an integer-option token to the dataclass field it populates.
    int_fields = {"niter": "lowrank_niter", "q": "lowrank_q", "seed": "lowrank_seed"}

    params = InitializeParams()
    pos = 0
    while pos < len(opts):
        tok = opts[pos]
        nxt = opts[pos + 1] if pos + 1 < len(opts) else None

        if tok in ("ipca", "lowrank"):
            # Explicit true/false consumes two tokens; a bare flag means True.
            if nxt in ("true", "false"):
                flag_value, pos = nxt == "true", pos + 2
            else:
                flag_value, pos = True, pos + 1
            if tok == "ipca":
                params.use_ipca = flag_value
            else:
                params.use_lowrank = flag_value
        elif tok in int_fields and nxt is not None and nxt.isdigit():
            setattr(params, int_fields[tok], int(nxt))
            pos += 2
        else:
            # Unknown token, or an int option without a numeric value: skip it.
            pos += 1

    return params
def initialize_urae(
    org_module: torch.nn.Module,
    lora_down: torch.nn.Module,
    lora_up: torch.nn.Module,
    scale: float,
    rank: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    use_ipca: bool = False,
    use_lowrank: bool = True,
    lowrank_q: Optional[int] = None,
    lowrank_niter: int = 4,
    lowrank_seed: Optional[int] = None,
):
    """
    URAE initialization: seed the LoRA pair with the *residual* (smallest)
    singular components of the original weight, then subtract the seeded
    product from the original weight so the module's output is unchanged.

    Args:
        org_module: Module whose ``.weight`` is decomposed (weight is modified
            in place: ``W <- W - scale * (up @ down)``).
        lora_down: LoRA down projection; its weight is overwritten.
        lora_up: LoRA up projection; its weight is overwritten.
        scale: LoRA scale applied when subtracting the seeded product.
        rank: Number of (smallest) singular components to keep.
        device: Computation device; defaults to CUDA when available.
        dtype: NOTE(review): resolved below but currently unused — the LoRA
            weights are left in float32, matching the prior behavior. Confirm
            whether a cast was intended.
        use_ipca: Use IncrementalPCA instead of a dense SVD.
        use_lowrank / lowrank_q / lowrank_niter / lowrank_seed: Forwarded to
            IncrementalPCA's internal low-rank solver (only used with
            ``use_ipca=True``; the dense-SVD fallback ignores them).
    """
    org_module_device = org_module.weight.device
    org_module_weight_dtype = org_module.weight.data.dtype
    org_module_requires_grad = org_module.weight.requires_grad

    dtype = dtype if dtype is not None else lora_down.weight.data.dtype
    device = device if device is not None else (torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"))
    assert isinstance(device, torch.device), f"Invalid device type: {device}"

    # Decompose in float32 for numerical stability regardless of module dtype.
    weight = org_module.weight.data.to(device, dtype=torch.float32)

    # no_grad for consistency with initialize_pissa (avoids autograd tracking).
    with torch.no_grad():
        if use_ipca:
            # URAE needs *all* components so the residual (trailing) ones exist.
            ipca = IncrementalPCA(
                n_components=None,  # keep every component
                batch_size=1024,
                lowrank=use_lowrank,
                lowrank_q=lowrank_q if lowrank_q is not None else min(weight.shape),  # full rank for accurate residuals
                lowrank_niter=lowrank_niter,
                lowrank_seed=lowrank_seed,
            )
            ipca.fit(weight)

            V_full = ipca.components_.T  # [out_features, total_rank]
            S_full = ipca.singular_values_  # [total_rank]

            # Smallest `rank` components. Divide OUT-OF-PLACE: `Sr /= rank`
            # on the slice would mutate ipca.singular_values_ through the view.
            Vr = V_full[:, -rank:]
            Sr = S_full[-rank:] / rank

            # Recover the trailing right singular vectors by transforming the
            # standard basis through the fitted IPCA.
            identity = torch.eye(weight.shape[1], device=weight.device)
            Uhr = ipca.transform(identity).T[-rank:]  # [rank, in_features]
        else:
            # Dense SVD fallback.
            V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
            Vr = V[:, -rank:]
            Sr = S[-rank:] / rank
            Uhr = Uh[-rank:, :]

        # Split sqrt(S) symmetrically between the two factors.
        down = torch.diag(torch.sqrt(Sr)) @ Uhr
        up = Vr @ torch.diag(torch.sqrt(Sr))

        # Sanity-check against the declared LoRA shapes (warn, don't fail).
        expected_down_shape = lora_down.weight.shape
        expected_up_shape = lora_up.weight.shape
        if down.shape != expected_down_shape:
            warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
        if up.shape != expected_up_shape:
            warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))

        # Seed the LoRA weights.
        lora_up.weight.data = up
        lora_down.weight.data = down

        # Remove the seeded component from the base weight so the initial
        # forward pass is identical to the un-adapted module.
        weight = weight - scale * (up @ down)
        org_module.weight.data = weight.to(org_module_device, dtype=org_module_weight_dtype)
        org_module.weight.requires_grad = org_module_requires_grad
(torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")) assert isinstance(device, torch.device), f"Invalid device type: {device}" weight = org_module.weight.data.clone().to(device, dtype=torch.float32) with torch.no_grad(): - # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}, - V, S, Uh = torch.linalg.svd(weight, full_matrices=False) - Vr = V[:, :rank] - Sr = S[:rank] - Sr /= rank - Uhr = Uh[:rank] + if use_ipca: + # Use Incremental PCA for large matrices + ipca = IncrementalPCA( + n_components=rank, + batch_size=1024, + lowrank=use_lowrank, + lowrank_q=lowrank_q if lowrank_q is not None else 2 * rank, + lowrank_niter=lowrank_niter, + lowrank_seed=lowrank_seed, + ) + ipca.fit(weight) + # Extract principal components and singular values + Vr = ipca.components_.T # [out_features, rank] + Sr = ipca.singular_values_ # [rank] + Sr /= rank + + # We need to get Uhr from transforming an identity matrix + identity = torch.eye(weight.shape[1], device=weight.device) + Uhr = ipca.transform(identity).T # [rank, in_features] + + elif use_lowrank: + # Use low-rank SVD approximation which is faster + seed_enabled = lowrank_seed is not None + q_value = lowrank_q if lowrank_q is not None else 2 * rank + + with torch.random.fork_rng(enabled=seed_enabled): + if seed_enabled: + torch.manual_seed(lowrank_seed) + U, S, V = torch.svd_lowrank(weight, q=q_value, niter=lowrank_niter) + + Vr = U[:, :rank] # First rank left singular vectors + Sr = S[:rank] # First rank singular values + Sr /= rank + Uhr = V[:rank] # First rank right singular vectors + + else: + # Standard SVD approach + V, S, Uh = torch.linalg.svd(weight, full_matrices=False) + Vr = V[:, :rank] + Sr = S[:rank] + Sr /= rank + Uhr = Uh[:rank] + + # Create down and up matrices down = torch.diag(torch.sqrt(Sr)) @ Uhr up = Vr @ torch.diag(torch.sqrt(Sr)) diff --git a/networks/lora_flux.py b/networks/lora_flux.py index b0c30ce7..e6780e21 100644 --- a/networks/lora_flux.py +++ 
b/networks/lora_flux.py @@ -7,6 +7,7 @@ # https://github.com/microsoft/LoRA/blob/main/loralib/layers.py # https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py +from dataclasses import asdict import math import os from typing import Dict, List, Optional, Type, Union @@ -109,32 +110,36 @@ class LoRAModule(torch.nn.Module): device: device to run initialization computation on """ if self.split_dims is None: - if initialize == "urae": - initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device) - # Need to store the original weights so we can get a plain LoRA out - self._org_lora_up = self.lora_up.weight.data.detach().clone() - self._org_lora_down = self.lora_down.weight.data.detach().clone() - elif initialize == "pissa": - initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device) - # Need to store the original weights so we can get a plain LoRA out - self._org_lora_up = self.lora_up.weight.data.detach().clone() - self._org_lora_down = self.lora_down.weight.data.detach().clone() + if initialize is not None: + params = initialize_parse_opts(initialize) + if initialize[:4] == "urae": + initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)) + # Need to store the original weights so we can get a plain LoRA out + self._org_lora_up = self.lora_up.weight.data.detach().clone() + self._org_lora_down = self.lora_down.weight.data.detach().clone() + elif initialize[:5] == "pissa": + initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim, device=device, **asdict(params)) + # Need to store the original weights so we can get a plain LoRA out + self._org_lora_up = self.lora_up.weight.data.detach().clone() + self._org_lora_down = self.lora_down.weight.data.detach().clone() else: initialize_lora(self.lora_down, self.lora_up) else: assert isinstance(self.lora_down, torch.nn.ModuleList) assert 
isinstance(self.lora_up, torch.nn.ModuleList) for lora_down, lora_up in zip(self.lora_down, self.lora_up): - if initialize == "urae": - initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=device) - # Need to store the original weights so we can get a plain LoRA out - self._org_lora_up = lora_up.weight.data.detach().clone() - self._org_lora_down = lora_down.weight.data.detach().clone() - elif initialize == "pissa": - initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=device) - # Need to store the original weights so we can get a plain LoRA out - self._org_lora_up = lora_up.weight.data.detach().clone() - self._org_lora_down = lora_down.weight.data.detach().clone() + if initialize is not None: + params = initialize_parse_opts(initialize) + if initialize[:4] == "urae": + initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=device, **asdict(params)) + # Need to store the original weights so we can get a plain LoRA out + self._org_lora_up = lora_up.weight.data.detach().clone() + self._org_lora_down = lora_down.weight.data.detach().clone() + elif initialize[:5] == "pissa": + initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim, device=device, **asdict(params)) + # Need to store the original weights so we can get a plain LoRA out + self._org_lora_up = lora_up.weight.data.detach().clone() + self._org_lora_down = lora_down.weight.data.detach().clone() else: initialize_lora(lora_down, lora_up) @@ -1305,7 +1310,8 @@ class LoRANetwork(torch.nn.Module): state_dict = self.state_dict() - if self.initialize in ['pissa']: + # Need to decompose the parameters into a LoRA format + if self.initialize is not None and (self.initialize[:5] == "pissa" or self.initialize[:4] == "urae"): loras: List[Union[LoRAModule, LoRAInfModule]] = self.text_encoder_loras + self.unet_loras def convert_pissa_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: 
Tensor, rank: int): # Calculate ΔW = A'B' - AB @@ -1325,14 +1331,49 @@ class LoRANetwork(torch.nn.Module): # These matrices can now be used as standard LoRA weights return new_up, new_down + def convert_urae_to_standard_lora(trained_up: Tensor, trained_down: Tensor, orig_up: Tensor, orig_down: Tensor, rank: int): + # Calculate ΔW = A'B' - AB + delta_w = (trained_up @ trained_down) - (orig_up @ orig_down) + + # We need to create new low-rank matrices that represent this delta + U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False) + + # For URAE, we want to focus on the smallest singular values + # Take the bottom rank*2 singular values (opposite of PiSSA which takes the top ones) + total_rank = len(S) + rank_to_use = min(rank * 2, total_rank) + + if rank_to_use < total_rank: + # Use the smallest singular values and vectors + selected_U = U[:, -rank_to_use:] + selected_S = S[-rank_to_use:] + selected_V = V[-rank_to_use:, :] + else: + # If we'd use all values, just use the standard approach but with a note + print("Warning: Requested rank is too large for URAE specialty, using all singular values") + selected_U = U + selected_S = S + selected_V = V + + # Create new LoRA matrices + new_up = selected_U @ torch.diag(torch.sqrt(selected_S)) + new_down = torch.diag(torch.sqrt(selected_S)) @ selected_V + + # These matrices can now be used as standard LoRA weights + return new_up, new_down + with torch.no_grad(): - progress = tqdm(total=len(loras), desc="Convert PiSSA") + progress = tqdm(total=len(loras), desc="Converting") for lora in loras: lora_up_key = f"{lora.lora_name}.lora_up.weight" lora_down_key = f"{lora.lora_name}.lora_down.weight" lora_up = state_dict[lora_up_key] lora_down = state_dict[lora_down_key] - up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim) + if self.initialize[:4] == "urae": + up, down = 
convert_urae_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim) + elif self.initialize[:5] == "pissa": + up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up.to(lora_up.device), lora._org_lora_down.to(lora_up.device), lora.lora_dim) + # TODO: Capture option if we should offload # offload to CPU state_dict[lora_up_key] = up.detach()