From 85928dd3b061d079461f97a40b039ac8b16d165c Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Mon, 24 Mar 2025 04:02:58 -0400
Subject: [PATCH] Add URAE and PiSSA initialization for FLUX LoRA

---
 library/train_util.py | 76 +++++++++++++++++++++++++++++++++++++++++++
 networks/lora_flux.py | 55 ++++++++++++++++++++++---------
 2 files changed, 116 insertions(+), 15 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 1f591c42..6643b5bf 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -6491,6 +6491,82 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
 # endregion
 
+def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
+    # default LoRA init: Kaiming-uniform down, zeros up
+    torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
+    torch.nn.init.zeros_(lora_up.weight)
+
+
+# URAE: Ultra-Resolution Adaptation with Ease
+def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None):
+    weight_dtype = org_module.weight.data.dtype
+    weight = org_module.weight.data.to(device=device if device is not None else org_module.weight.device, dtype=torch.float32)
+
+    # SVD decomposition
+    V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
+
+    # URAE uses the LAST/SMALLEST singular values and vectors (minor components)
+    Vr = V[:, -rank:]
+    Sr = S[-rank:]
+    Sr /= scale  # pre-divide by the network scale so scale * (up @ down) reproduces the extracted component exactly
+    Uhr = Uh[-rank:, :]
+
+    # factor the component into down and up matrices
+    down = torch.diag(torch.sqrt(Sr)) @ Uhr
+    up = Vr @ torch.diag(torch.sqrt(Sr))
+
+    # verify shapes match the LoRA modules
+    if down.shape != lora_down.weight.shape:
+        logger.warning(f"Down matrix shape mismatch. Got {down.shape}, expected {lora_down.weight.shape}")
+    if up.shape != lora_up.weight.shape:
+        logger.warning(f"Up matrix shape mismatch. Got {up.shape}, expected {lora_up.weight.shape}")
+
+    # assign to LoRA weights
+    lora_up.weight.data = up
+    lora_down.weight.data = down
+
+    # subtract the extracted component so the merged model is unchanged at initialization
+    weight = weight - scale * (up @ down)
+    org_module.weight.data = weight.to(dtype=weight_dtype)
+
+
+# PiSSA: Principal Singular Values and Singular Vectors Adaptation
+def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None):
+    weight_dtype = org_module.weight.data.dtype
+    weight = org_module.weight.data.to(device=device if device is not None else org_module.weight.device, dtype=torch.float32)
+
+    # U S V^T = W <-> V S U^T = W^T, where weight is in R^{out_features x in_features} for nn.Linear
+    V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
+
+    # PiSSA uses the FIRST/LARGEST singular values and vectors (principal components)
+    Vr = V[:, :rank]
+    Sr = S[:rank]
+    Sr /= scale  # pre-divide by the network scale so scale * (up @ down) reproduces the extracted component exactly
+    Uhr = Uh[:rank, :]
+
+    down = torch.diag(torch.sqrt(Sr)) @ Uhr
+    up = Vr @ torch.diag(torch.sqrt(Sr))
+
+    # verify shapes match the LoRA modules
+    if down.shape != lora_down.weight.shape:
+        logger.warning(f"Down matrix shape mismatch. Got {down.shape}, expected {lora_down.weight.shape}")
+    if up.shape != lora_up.weight.shape:
+        logger.warning(f"Up matrix shape mismatch. Got {up.shape}, expected {lora_up.weight.shape}")
+
+    lora_up.weight.data = up
+    lora_down.weight.data = down
+
+    # subtract the extracted component so the merged model is unchanged at initialization
+    weight = weight - scale * (up @ down)
+    org_module.weight.data = weight.to(dtype=weight_dtype)
+
+
 # for collate_fn; epoch and step are multiprocessing.Value
 class collator_class:
     def __init__(self, epoch, step, dataset):
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 91e9cd77..3c405c7b 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -16,7 +16,7 @@ import numpy as np
 import torch
 import re
 from library.utils import setup_logging
-from library.sdxl_original_unet import SdxlUNet2DConditionModel
+from library.train_util import initialize_lora, initialize_pissa, initialize_urae
 
 setup_logging()
 import logging
@@ -44,6 +44,7 @@ class LoRAModule(torch.nn.Module):
         rank_dropout=None,
         module_dropout=None,
         split_dims: Optional[List[int]] = None,
+        initialize: Optional[str] = None,
     ):
         """
         if alpha == 0 or None, alpha is rank (no scaling).
@@ -61,6 +62,16 @@ class LoRAModule(torch.nn.Module):
             out_dim = org_module.out_features
 
         self.lora_dim = lora_dim
+
+        if type(alpha) == torch.Tensor:
+            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
+        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
+        rank_factor = self.lora_dim
+        if rank_stabilized:
+            rank_factor = math.sqrt(rank_factor)
+        self.scale = alpha / rank_factor
+        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+
         self.split_dims = split_dims
 
         if split_dims is None:
@@ -74,8 +85,12 @@ class LoRAModule(torch.nn.Module):
             self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
             self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
 
-            torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
-            torch.nn.init.zeros_(self.lora_up.weight)
+            if initialize == "urae":
+                initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
+            elif initialize == "pissa":
+                initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
+            else:
+                initialize_lora(self.lora_down, self.lora_up)
         else:
             # conv2d not supported
             assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
@@ -85,16 +100,13 @@ class LoRAModule(torch.nn.Module):
             self.lora_down = torch.nn.ModuleList(
                 [torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
             )
             self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
-            for lora_down in self.lora_down:
-                torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
-            for lora_up in self.lora_up:
-                torch.nn.init.zeros_(lora_up.weight)
-
-        if type(alpha) == torch.Tensor:
-            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
-        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
-        self.scale = alpha / self.lora_dim
-        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant
+            for lora_down, lora_up in zip(self.lora_down, self.lora_up):
+                if initialize == "urae":
+                    initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
+                elif initialize == "pissa":
+                    initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
+                else:
+                    initialize_lora(lora_down, lora_up)
 
         # same as microsoft's
         self.multiplier = multiplier
@@ -420,6 +432,7 @@ def create_network(
     if split_qkv is not None:
         split_qkv = True if split_qkv == "True" else False
 
+    initialize = kwargs.get("initialize", None)
     # train T5XXL
     train_t5xxl = kwargs.get("train_t5xxl", False)
    if train_t5xxl is not None:
@@ -449,6 +462,7 @@ def create_network(
         in_dims=in_dims,
         train_double_block_indices=train_double_block_indices,
         train_single_block_indices=train_single_block_indices,
+        initialize=initialize,
         verbose=verbose,
     )
 
@@ -561,6 +575,7 @@ class LoRANetwork(torch.nn.Module):
         in_dims: Optional[List[int]] = None,
         train_double_block_indices: Optional[List[bool]] = None,
         train_single_block_indices: Optional[List[bool]] = None,
+        initialize: Optional[str] = None,
         verbose: Optional[bool] = False,
     ) -> None:
         super().__init__()
@@ -722,6 +737,7 @@ class LoRANetwork(torch.nn.Module):
                             rank_dropout=rank_dropout,
                             module_dropout=module_dropout,
                             split_dims=split_dims,
+                            initialize=initialize,
                         )
                         loras.append(lora)
 
@@ -740,8 +756,11 @@ class LoRANetwork(torch.nn.Module):
 
             logger.info(f"create LoRA for Text Encoder {index+1}:")
 
+            if initialize is not None:
+                logger.info(f"Initialize Text Encoder LoRA using {initialize}")
+
             text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
-            logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
+            logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index+1}.")
             self.text_encoder_loras.extend(text_encoder_loras)
             skipped_te += skipped
 
@@ -753,6 +772,11 @@ class LoRANetwork(torch.nn.Module):
         elif self.train_blocks == "double":
             target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE
 
+        logger.info("create LoRA for FLUX")
+
+        if initialize is not None:
+            logger.info(f"Initialize FLUX LoRA using {initialize}")
+
         self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
         self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
 
@@ -762,7 +786,8 @@ class LoRANetwork(torch.nn.Module):
                 loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
                 self.unet_loras.extend(loras)
 
-        logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
+        logger.info(f"FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
+
         if verbose:
             for lora in self.unet_loras:
                 logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
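
Usage note: initialize is read from the network_args kwargs in create_network, so with the usual sd-scripts command line it should be selectable by passing --network_args "initialize=pissa" (or "initialize=urae") alongside --network_module networks.lora_flux; omitting it, or passing any other value, falls back to the default Kaiming-uniform/zeros initialization. The flag spelling assumes the standard network_args-to-kwargs plumbing in sd-scripts.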
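
The Sr /= scale step is what makes the decomposition exact: the adapter output is multiplied by self.scale in the LoRA forward pass, so the singular values stored in the factors must be pre-divided by that scale. A standalone sketch with arbitrary toy dimensions (not part of the patch; only torch required) that checks both invariants of initialize_pissa:

    import torch

    torch.manual_seed(0)
    out_dim, in_dim, rank, alpha = 64, 32, 8, 16.0
    scale = alpha / rank  # 2.0, mirrors LoRAModule.scale without rank stabilization

    W = torch.randn(out_dim, in_dim)  # stands in for org_module.weight
    V, S, Uh = torch.linalg.svd(W, full_matrices=False)
    Sr = S[:rank] / scale                             # pre-divide by scale, as in initialize_pissa
    down = torch.diag(torch.sqrt(Sr)) @ Uh[:rank, :]  # (rank, in_dim)  -> lora_down.weight
    up = V[:, :rank] @ torch.diag(torch.sqrt(Sr))     # (out_dim, rank) -> lora_up.weight

    # 1) the scaled adapter carries exactly the rank-r principal component of W
    W_r = V[:, :rank] @ torch.diag(S[:rank]) @ Uh[:rank, :]
    assert torch.allclose(scale * (up @ down), W_r, atol=1e-5)

    # 2) after subtracting that component from W, the merged output is unchanged at init
    W_res = W - scale * (up @ down)
    assert torch.allclose(W_res + scale * (up @ down), W, atol=1e-5)

The same algebra applies to initialize_urae, which takes the smallest rank singular values instead of the largest.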