mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Add initialization URAE, PiSSA for flux
This commit is contained in:
@@ -6491,6 +6491,82 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
|
||||
# endregion
|
||||
|
||||
|
||||
def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
    """Apply the default LoRA initialization to a down/up projection pair.

    The down projection receives Kaiming-uniform values and the up projection
    is zeroed, so the product ``up @ down`` is exactly zero right after
    initialization.

    Args:
        lora_down: Module whose ``weight`` is filled with Kaiming-uniform values.
        lora_up: Module whose ``weight`` is filled with zeros.
    """
    # Same negative-slope value that is passed to kaiming_uniform_ elsewhere in this file.
    negative_slope = math.sqrt(5)

    down_weight = lora_down.weight
    torch.nn.init.kaiming_uniform_(down_weight, a=negative_slope)

    up_weight = lora_up.weight
    torch.nn.init.zeros_(up_weight)
|
||||
|
||||
# URAE: Ultra-Resolution Adaptation with Ease
|
||||
def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None):
    """Initialize a LoRA pair from the smallest singular components of ``org_module``.

    The weight of ``org_module`` is decomposed with an SVD in float32. The
    ``rank`` smallest singular values (each divided by ``rank``) and their
    singular vectors are split symmetrically into ``lora_down`` and
    ``lora_up``. ``scale * (up @ down)`` is then subtracted from the original
    weight, so ``org_weight + scale * (up @ down)`` reproduces the weight the
    module had before the call (up to rounding).

    Both ``lora_down.weight.data``/``lora_up.weight.data`` and
    ``org_module.weight.data`` are replaced by tensors living on the compute
    device. The LoRA factors are float32; the original weight keeps its dtype.

    Args:
        org_module: Module whose 2-D ``weight`` is decomposed and updated.
        lora_down: Module whose ``weight`` receives ``diag(sqrt(S_r)) @ Vh_r``.
        lora_up: Module whose ``weight`` receives ``U_r @ diag(sqrt(S_r))``.
        scale: Multiplier applied to ``up @ down`` when it is subtracted from
            the original weight.
        rank: Number of singular components to move into the LoRA pair.
        device: Device used for the SVD. When ``None``, CUDA is used if it is
            available, otherwise the device the weight already lives on.
        dtype: Currently unused; accepted for interface compatibility.

    Raises:
        ValueError: If ``rank`` is not positive.
    """
    if rank <= 0:
        # A slice such as ``x[-0:]`` would select everything and S / 0 is inf.
        raise ValueError(f"rank must be a positive integer, got {rank}")

    weight_dtype = org_module.weight.data.dtype

    # Honor the requested device instead of unconditionally requiring CUDA.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else org_module.weight.data.device
    weight = org_module.weight.data.to(device=device, dtype=torch.float32)

    # weight = U @ diag(S) @ Vh, with S sorted in descending order.
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    # For URAE, use the LAST/SMALLEST singular values and vectors (residual components).
    Ur = U[:, -rank:]
    Sr = S[-rank:] / rank  # out-of-place: do not mutate the SVD result through a view
    Vhr = Vh[-rank:, :]

    # Split each singular value evenly between the two factors.
    sqrt_Sr = torch.sqrt(Sr)
    down = torch.diag(sqrt_Sr) @ Vhr
    up = Ur @ torch.diag(sqrt_Sr)

    expected_down_shape = lora_down.weight.shape
    expected_up_shape = lora_up.weight.shape

    # A mismatch is only reported; the factors are still assigned below.
    if down.shape != expected_down_shape:
        print(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")

    if up.shape != expected_up_shape:
        print(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}")

    lora_up.weight.data = up
    lora_down.weight.data = down

    # Remove the part now carried by the LoRA pair from the original weight.
    weight = weight - scale * (up @ down)
    org_module.weight.data = weight.to(dtype=weight_dtype)
|
||||
|
||||
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
|
||||
def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None):
    """Initialize a LoRA pair from the principal singular components of ``org_module``.

    The weight of ``org_module`` is decomposed with an SVD in float32. The
    ``rank`` largest singular values (each divided by ``rank``) and their
    singular vectors are split symmetrically into ``lora_down`` and
    ``lora_up``. ``scale * (up @ down)`` is then subtracted from the original
    weight, so ``org_weight + scale * (up @ down)`` reproduces the weight the
    module had before the call (up to rounding).

    Both ``lora_down.weight.data``/``lora_up.weight.data`` and
    ``org_module.weight.data`` are replaced by tensors living on the compute
    device. The LoRA factors are float32; the original weight keeps its dtype.

    Args:
        org_module: Module whose 2-D ``weight`` is decomposed and updated.
        lora_down: Module whose ``weight`` receives ``diag(sqrt(S_r)) @ Vh_r``.
        lora_up: Module whose ``weight`` receives ``U_r @ diag(sqrt(S_r))``.
        scale: Multiplier applied to ``up @ down`` when it is subtracted from
            the original weight.
        rank: Number of singular components to move into the LoRA pair.
        device: Device used for the SVD. When ``None``, CUDA is used if it is
            available, otherwise the device the weight already lives on.
        dtype: Currently unused; accepted for interface compatibility.

    Raises:
        ValueError: If ``rank`` is not positive.
    """
    if rank <= 0:
        # rank == 0 would yield empty factors and a division by zero.
        raise ValueError(f"rank must be a positive integer, got {rank}")

    weight_dtype = org_module.weight.data.dtype

    # Honor the requested device instead of unconditionally requiring CUDA.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else org_module.weight.data.device
    weight = org_module.weight.data.to(device=device, dtype=torch.float32)

    # weight = U @ diag(S) @ Vh, with S sorted in descending order, so the
    # leading ``rank`` entries are the principal components.
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
    Ur = U[:, :rank]
    Sr = S[:rank] / rank  # out-of-place: do not mutate the SVD result through a view
    Vhr = Vh[:rank]

    # Split each singular value evenly between the two factors.
    sqrt_Sr = torch.sqrt(Sr)
    down = torch.diag(sqrt_Sr) @ Vhr
    up = Ur @ torch.diag(sqrt_Sr)

    expected_down_shape = lora_down.weight.shape
    expected_up_shape = lora_up.weight.shape

    # A mismatch is only reported; the factors are still assigned below.
    if down.shape != expected_down_shape:
        print(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")

    if up.shape != expected_up_shape:
        print(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}")

    lora_up.weight.data = up
    lora_down.weight.data = down

    # Remove the part now carried by the LoRA pair from the original weight.
    weight = weight - scale * (up @ down)
    org_module.weight.data = weight.to(dtype=weight_dtype)
|
||||
|
||||
# For collate_fn: epoch and step are multiprocessing.Value objects
|
||||
class collator_class:
|
||||
def __init__(self, epoch, step, dataset):
|
||||
|
||||
@@ -16,7 +16,7 @@ import numpy as np
|
||||
import torch
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.sdxl_original_unet import SdxlUNet2DConditionModel
|
||||
from library.train_util import initialize_lora, initialize_pissa, initialize_urae
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -44,6 +44,7 @@ class LoRAModule(torch.nn.Module):
|
||||
rank_dropout=None,
|
||||
module_dropout=None,
|
||||
split_dims: Optional[List[int]] = None,
|
||||
initialize: Optional[str]=None
|
||||
):
|
||||
"""
|
||||
if alpha == 0 or None, alpha is rank (no scaling).
|
||||
@@ -61,6 +62,16 @@ class LoRAModule(torch.nn.Module):
|
||||
out_dim = org_module.out_features
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
rank_factor = self.lora_dim
|
||||
if rank_stabilized:
|
||||
rank_factor = math.sqrt(rank_factor)
|
||||
self.scale = alpha / rank_factor
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
self.split_dims = split_dims
|
||||
|
||||
if split_dims is None:
|
||||
@@ -74,8 +85,12 @@ class LoRAModule(torch.nn.Module):
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
|
||||
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
|
||||
torch.nn.init.zeros_(self.lora_up.weight)
|
||||
if initialize == "urae":
|
||||
initialize_urae(org_module, self.lora_down, self.lora_up, self.lora_dim)
|
||||
elif initialize == "pissa":
|
||||
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
|
||||
else:
|
||||
initialize_lora(self.lora_down, self.lora_up)
|
||||
else:
|
||||
# conv2d not supported
|
||||
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
|
||||
@@ -85,16 +100,13 @@ class LoRAModule(torch.nn.Module):
|
||||
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
|
||||
)
|
||||
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
|
||||
for lora_down in self.lora_down:
|
||||
torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
|
||||
for lora_up in self.lora_up:
|
||||
torch.nn.init.zeros_(lora_up.weight)
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
|
||||
if initialize == "urae":
|
||||
initialize_urae(org_module, lora_down, lora_up, self.lora_dim)
|
||||
elif initialize == "pissa":
|
||||
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
|
||||
else:
|
||||
initialize_lora(lora_down, lora_up)
|
||||
|
||||
# same as microsoft's
|
||||
self.multiplier = multiplier
|
||||
@@ -420,6 +432,7 @@ def create_network(
|
||||
if split_qkv is not None:
|
||||
split_qkv = True if split_qkv == "True" else False
|
||||
|
||||
initialize = kwargs.get("initialize", None)
|
||||
# train T5XXL
|
||||
train_t5xxl = kwargs.get("train_t5xxl", False)
|
||||
if train_t5xxl is not None:
|
||||
@@ -449,6 +462,7 @@ def create_network(
|
||||
in_dims=in_dims,
|
||||
train_double_block_indices=train_double_block_indices,
|
||||
train_single_block_indices=train_single_block_indices,
|
||||
initialize=initialize,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@@ -561,6 +575,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
in_dims: Optional[List[int]] = None,
|
||||
train_double_block_indices: Optional[List[bool]] = None,
|
||||
train_single_block_indices: Optional[List[bool]] = None,
|
||||
initialize: Optional[str] = None,
|
||||
verbose: Optional[bool] = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -722,6 +737,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
rank_dropout=rank_dropout,
|
||||
module_dropout=module_dropout,
|
||||
split_dims=split_dims,
|
||||
initialize=initialize,
|
||||
)
|
||||
loras.append(lora)
|
||||
|
||||
@@ -740,8 +756,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
logger.info(f"create LoRA for Text Encoder {index+1}:")
|
||||
|
||||
if initialize is not None:
|
||||
logger.info(f"Initialize Text Encoder LoRA using {initialize}")
|
||||
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
logger.info(f"create LoRA for Text Encoder {index+1}: {len(text_encoder_loras)} modules.")
|
||||
logger.info(f"created {len(text_encoder_loras)} modules for Text Encoder {index+1}.")
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
|
||||
@@ -753,6 +772,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
elif self.train_blocks == "double":
|
||||
target_replace_modules = LoRANetwork.FLUX_TARGET_REPLACE_MODULE_DOUBLE
|
||||
|
||||
logger.info("create LoRA for FLUX")
|
||||
|
||||
if initialize is not None:
|
||||
logger.info(f"Initialize FLUX LoRA using {initialize}")
|
||||
|
||||
self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_replace_modules)
|
||||
|
||||
@@ -762,7 +786,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
loras, _ = create_modules(True, None, unet, None, filter=filter, default_dim=in_dim)
|
||||
self.unet_loras.extend(loras)
|
||||
|
||||
logger.info(f"create LoRA for FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
|
||||
logger.info(f"FLUX {self.train_blocks} blocks: {len(self.unet_loras)} modules.")
|
||||
|
||||
if verbose:
|
||||
for lora in self.unet_loras:
|
||||
logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")
|
||||
|
||||
Reference in New Issue
Block a user