mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Update initialization, add lora_util, add tests
This commit is contained in:
87
library/lora_util.py
Normal file
87
library/lora_util.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
    """Standard LoRA init: Kaiming-uniform down projection, zeroed up projection.

    Zeroing ``lora_up`` makes the initial LoRA delta (up @ down) exactly zero,
    so training starts from the unmodified base weights.
    """
    down_weight, up_weight = lora_down.weight, lora_up.weight
    torch.nn.init.kaiming_uniform_(down_weight, a=math.sqrt(5))
    torch.nn.init.zeros_(up_weight)
|
||||
|
||||
# URAE: Ultra-Resolution Adaptation with Ease
|
||||
# URAE: Ultra-Resolution Adaptation with Ease
def initialize_urae(
    org_module: torch.nn.Module,
    lora_down: torch.nn.Module,
    lora_up: torch.nn.Module,
    scale: float,
    rank: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Initialize a LoRA pair from the *smallest* singular components of org_module.

    URAE seeds the adapter with the residual (last ``rank``) singular
    values/vectors of the base weight and subtracts the scaled product from the
    base, so at init ``base + scale * (up @ down)`` reproduces the original weight.

    Args:
        org_module: module owning the 2-D ``weight`` to decompose (e.g. Linear).
        lora_down: module receiving the down-projection factor.
        lora_up: module receiving the up-projection factor.
        scale: multiplier applied to the LoRA product when subtracting it.
        rank: number of trailing singular components to extract.
        device: device used for the SVD; defaults to lora_down's weight device.
        dtype: optional dtype for the new LoRA weights; defaults to each
            module's existing weight dtype (previously this argument was ignored
            and the factors stayed float32 regardless of the LoRA dtype).
    """
    device = device if device is not None else lora_down.weight.data.device
    if not isinstance(device, torch.device):
        raise TypeError(f"Invalid device type: {device}")

    weight_dtype = org_module.weight.data.dtype
    # SVD in float32 for numerical stability; clone so the base weight is never
    # aliased (``.to`` can be a no-op when device/dtype already match).
    weight = org_module.weight.data.clone().to(device, dtype=torch.float32)

    # NOTE: no autocast here on purpose — float32 is not a supported CPU
    # autocast dtype (torch disables autocast with a warning) and is a no-op on
    # CUDA; the computation is already carried out in float32.
    # SVD decomposition
    V, S, Uh = torch.linalg.svd(weight, full_matrices=False)

    # For URAE, use the LAST/SMALLEST singular values and vectors (residual components)
    Vr = V[:, -rank:]
    Sr = S[-rank:] / rank
    Uhr = Uh[-rank:, :]

    # Split sqrt(S) symmetrically between the two factors
    down = torch.diag(torch.sqrt(Sr)) @ Uhr
    up = Vr @ torch.diag(torch.sqrt(Sr))

    expected_down_shape = lora_down.weight.shape
    expected_up_shape = lora_up.weight.shape

    if down.shape != expected_down_shape:
        warnings.warn(UserWarning(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
    if up.shape != expected_up_shape:
        warnings.warn(UserWarning(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))

    # Assign to LoRA weights, cast to each module's dtype (or the explicit
    # override) for consistency with initialize_pissa.
    lora_up.weight.data = up.to(dtype=dtype if dtype is not None else lora_up.weight.dtype)
    lora_down.weight.data = down.to(dtype=dtype if dtype is not None else lora_down.weight.dtype)

    # Subtract the scaled adapter contribution from the original weight so the
    # effective weight is unchanged at init.
    org_module.weight.data = (weight - scale * (up @ down)).to(dtype=weight_dtype)
|
||||
|
||||
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
|
||||
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
def initialize_pissa(
    org_module: torch.nn.Module,
    lora_down: torch.nn.Module,
    lora_up: torch.nn.Module,
    scale: float,
    rank: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Initialize a LoRA pair from the top-``rank`` singular components of org_module.

    The principal components are moved into (lora_down, lora_up) and the scaled
    product is subtracted from the base weight, so at init
    ``base + scale * (up @ down)`` reproduces the original weight.

    Args:
        org_module: module owning the 2-D ``weight`` to decompose (e.g. Linear).
        lora_down: module receiving the down-projection factor.
        lora_up: module receiving the up-projection factor.
        scale: multiplier applied to the LoRA product when subtracting it.
        rank: number of leading singular components to extract.
        device: device used for the SVD; defaults to lora_down's weight device.
        dtype: optional dtype for the new LoRA weights; defaults to each
            module's existing weight dtype (previously this argument was ignored).
    """
    weight_dtype = org_module.weight.data.dtype

    device = device if device is not None else lora_down.weight.data.device
    if not isinstance(device, torch.device):
        raise TypeError(f"Invalid device type: {device}")

    # Work on a float32 copy for numerical stability.
    weight = org_module.weight.data.clone().to(device, dtype=torch.float32)

    # NOTE: no autocast here on purpose — float32 is not a valid CPU autocast
    # dtype (torch disables autocast with a warning) and is a no-op on CUDA.
    # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel}
    V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
    Vr = V[:, :rank]
    Sr = S[:rank] / rank
    Uhr = Uh[:rank]

    down = torch.diag(torch.sqrt(Sr)) @ Uhr
    up = Vr @ torch.diag(torch.sqrt(Sr))

    # Get expected shapes
    expected_down_shape = lora_down.weight.shape
    expected_up_shape = lora_up.weight.shape

    # Verify shapes match expected
    if down.shape != expected_down_shape:
        warnings.warn(UserWarning(f"Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}"))
    if up.shape != expected_up_shape:
        warnings.warn(UserWarning(f"Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}"))

    # Fix: each projection is cast to its OWN dtype (the down projection was
    # previously cast to lora_up's dtype by mistake); ``dtype`` overrides both.
    lora_up.weight.data = up.to(dtype=dtype if dtype is not None else lora_up.weight.dtype)
    lora_down.weight.data = down.to(dtype=dtype if dtype is not None else lora_down.weight.dtype)

    # Residual base weight: base + scale * (up @ down) reproduces the original.
    org_module.weight.data = (weight - scale * (up @ down)).to(dtype=weight_dtype)
|
||||
@@ -1,5 +1,4 @@
|
||||
# common functions for training
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
@@ -6490,83 +6489,6 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def initialize_lora(lora_down: torch.nn.Module, lora_up: torch.nn.Module):
    """Default LoRA weight init.

    The up projection is zeroed so the adapter contributes nothing before
    training; the down projection gets Kaiming-uniform noise.
    """
    torch.nn.init.zeros_(lora_up.weight)
    torch.nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
|
||||
|
||||
# URAE: Ultra-Resolution Adaptation with Ease
|
||||
# URAE: Ultra-Resolution Adaptation with Ease
def initialize_urae(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None):
    """Seed (lora_down, lora_up) from the smallest ``rank`` singular components
    of ``org_module.weight`` and subtract the scaled product from the base, so
    ``base + scale * (up @ down)`` equals the original weight at init.
    """
    weight_dtype = org_module.weight.data.dtype
    # Fix: honor the `device` argument (fall back to the weight's own device)
    # instead of hard-coding "cuda", which crashed on CPU-only machines.
    if device is None:
        device = org_module.weight.data.device
    weight = org_module.weight.data.to(device=device, dtype=torch.float32)

    # SVD decomposition (done in float32 for stability)
    V, S, Uh = torch.linalg.svd(weight, full_matrices=False)

    # For URAE, use the LAST/SMALLEST singular values and vectors (residual components)
    Vr = V[:, -rank:]
    Sr = S[-rank:] / rank
    Uhr = Uh[-rank:, :]

    # Create down and up matrices, splitting sqrt(S) between the two factors
    down = torch.diag(torch.sqrt(Sr)) @ Uhr
    up = Vr @ torch.diag(torch.sqrt(Sr))

    # Get expected shapes
    expected_down_shape = lora_down.weight.shape
    expected_up_shape = lora_up.weight.shape

    # Verify shapes match expected
    if down.shape != expected_down_shape:
        print(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")
    if up.shape != expected_up_shape:
        print(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}")

    # Assign to LoRA weights
    lora_up.weight.data = up
    lora_down.weight.data = down

    # Subtract the scaled adapter contribution from the original weight
    org_module.weight.data = (weight - scale * (up @ down)).to(dtype=weight_dtype)
|
||||
|
||||
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
|
||||
# PiSSA: Principal Singular Values and Singular Vectors Adaptation
def initialize_pissa(org_module: torch.nn.Module, lora_down: torch.nn.Module, lora_up: torch.nn.Module, scale: float, rank: int, device=None, dtype=None):
    """Seed (lora_down, lora_up) from the top-``rank`` singular components of
    ``org_module.weight`` and subtract the scaled product from the base, so
    ``base + scale * (up @ down)`` equals the original weight at init.
    """
    weight_dtype = org_module.weight.data.dtype

    # Fix: honor the `device` argument (fall back to the weight's own device)
    # instead of hard-coding "cuda", which crashed on CPU-only machines.
    if device is None:
        device = org_module.weight.data.device
    weight = org_module.weight.data.to(device=device, dtype=torch.float32)

    # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
    V, S, Uh = torch.linalg.svd(weight, full_matrices=False)
    Vr = V[:, :rank]
    Sr = S[:rank] / rank
    Uhr = Uh[:rank]

    # Split sqrt(S) symmetrically between the two factors
    down = torch.diag(torch.sqrt(Sr)) @ Uhr
    up = Vr @ torch.diag(torch.sqrt(Sr))

    # Get expected shapes
    expected_down_shape = lora_down.weight.shape
    expected_up_shape = lora_up.weight.shape

    # Verify shapes match expected or reshape appropriately
    if down.shape != expected_down_shape:
        print(f"Warning: Down matrix shape mismatch. Got {down.shape}, expected {expected_down_shape}")
    if up.shape != expected_up_shape:
        print(f"Warning: Up matrix shape mismatch. Got {up.shape}, expected {expected_up_shape}")

    lora_up.weight.data = up
    lora_down.weight.data = down

    # Residual base weight: base + scale * (up @ down) reproduces the original
    org_module.weight.data = (weight - scale * (up @ down)).to(dtype=weight_dtype)
|
||||
|
||||
# collate_fn用 epoch,stepはmultiprocessing.Value
|
||||
class collator_class:
|
||||
def __init__(self, epoch, step, dataset):
|
||||
|
||||
@@ -18,7 +18,7 @@ from torch import Tensor
|
||||
from tqdm import tqdm
|
||||
import re
|
||||
from library.utils import setup_logging
|
||||
from library.train_util import initialize_lora, initialize_pissa, initialize_urae
|
||||
from library.lora_util import initialize_lora, initialize_pissa, initialize_urae
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
@@ -38,7 +38,7 @@ class LoRAModule(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
lora_name,
|
||||
org_module: torch.nn.Module,
|
||||
org_module: torch.nn.Linear,
|
||||
multiplier=1.0,
|
||||
lora_dim=4,
|
||||
alpha=1,
|
||||
@@ -56,68 +56,32 @@ class LoRAModule(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.lora_name = lora_name
|
||||
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
in_dim = org_module.in_channels
|
||||
out_dim = org_module.out_channels
|
||||
else:
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
in_dim = org_module.in_features
|
||||
out_dim = org_module.out_features
|
||||
|
||||
self.lora_dim = lora_dim
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
alpha = alpha.detach().float().item() # without casting, bf16 causes error
|
||||
alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
|
||||
self.scale = alpha / self.lora_dim
|
||||
self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
|
||||
|
||||
self.split_dims = split_dims
|
||||
self.initialize = initialize
|
||||
|
||||
if split_dims is None:
|
||||
if org_module.__class__.__name__ == "Conv2d":
|
||||
kernel_size = org_module.kernel_size
|
||||
stride = org_module.stride
|
||||
padding = org_module.padding
|
||||
self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
|
||||
self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
|
||||
else:
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
|
||||
if initialize == "urae":
|
||||
initialize_urae(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||
elif initialize == "pissa":
|
||||
initialize_pissa(org_module, self.lora_down, self.lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = self.lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = self.lora_down.weight.data.detach().clone()
|
||||
else:
|
||||
initialize_lora(self.lora_down, self.lora_up)
|
||||
self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False)
|
||||
self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False)
|
||||
else:
|
||||
# conv2d not supported
|
||||
assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
|
||||
assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
|
||||
# print(f"split_dims: {split_dims}")
|
||||
self.lora_down = torch.nn.ModuleList(
|
||||
[torch.nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
|
||||
)
|
||||
self.lora_up = torch.nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])
|
||||
for lora_down, lora_up in zip(self.lora_down, self.lora_up):
|
||||
if initialize == "urae":
|
||||
initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = lora_down.weight.data.detach().clone()
|
||||
elif initialize == "pissa":
|
||||
initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
|
||||
# Need to store the original weights so we can get a plain LoRA out
|
||||
self._org_lora_up = lora_up.weight.data.detach().clone()
|
||||
self._org_lora_down = lora_down.weight.data.detach().clone()
|
||||
else:
|
||||
initialize_lora(lora_down, lora_up)
|
||||
|
||||
with torch.autocast(org_module.weight.device.type), torch.no_grad():
|
||||
self.initialize_weights(org_module)
|
||||
|
||||
# same as microsoft's
|
||||
self.multiplier = multiplier
|
||||
@@ -126,12 +90,44 @@ class LoRAModule(torch.nn.Module):
|
||||
self.rank_dropout = rank_dropout
|
||||
self.module_dropout = module_dropout
|
||||
|
||||
|
||||
def initialize_weights(self, org_module: torch.nn.Module):
    """Initialize the LoRA projection weights per the configured scheme.

    Dispatches on ``self.initialize`` ("urae" / "pissa" / default) for either a
    single down/up pair or, when ``split_dims`` is set, each pair in the
    ModuleLists. The previous implementation duplicated the three-way dispatch
    in both branches; it is factored into one local helper here.
    """

    def _init_pair(lora_down, lora_up):
        # Initialize one down/up pair with the selected scheme.
        if self.initialize == "urae":
            initialize_urae(org_module, lora_down, lora_up, self.scale, self.lora_dim)
        elif self.initialize == "pissa":
            initialize_pissa(org_module, lora_down, lora_up, self.scale, self.lora_dim)
        else:
            initialize_lora(lora_down, lora_up)
            return
        # urae/pissa: keep copies of the initial factors so a plain LoRA can be
        # recovered later (the trained delta is taken relative to these).
        self._org_lora_up = lora_up.weight.data.detach().clone()
        self._org_lora_down = lora_down.weight.data.detach().clone()

    if self.split_dims is None:
        _init_pair(self.lora_down, self.lora_up)
    else:
        assert isinstance(self.lora_down, torch.nn.ModuleList)
        assert isinstance(self.lora_up, torch.nn.ModuleList)
        for lora_down, lora_up in zip(self.lora_down, self.lora_up):
            _init_pair(lora_down, lora_up)
|
||||
|
||||
def apply_to(self):
|
||||
self.org_forward = self.org_module.forward
|
||||
self.org_module.forward = self.forward
|
||||
del self.org_module
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
org_forwarded = self.org_forward(x)
|
||||
|
||||
# module dropout
|
||||
@@ -175,10 +171,6 @@ class LoRAModule(torch.nn.Module):
|
||||
if self.rank_dropout is not None and self.training:
|
||||
masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
|
||||
for i in range(len(lxs)):
|
||||
if len(lx.size()) == 3:
|
||||
masks[i] = masks[i].unsqueeze(1)
|
||||
elif len(lx.size()) == 4:
|
||||
masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
|
||||
lxs[i] = lxs[i] * masks[i]
|
||||
|
||||
# scaling for rank dropout: treat as if the rank is changed
|
||||
@@ -765,6 +757,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
# 毎回すべてのモジュールを作るのは無駄なので要検討
|
||||
self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
|
||||
skipped_te = []
|
||||
text_encoders = text_encoders if isinstance(text_encoders, list) else [text_encoders]
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
index = i
|
||||
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
|
||||
@@ -1103,8 +1096,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
delta_w = (trained_up @ trained_down) - (orig_up @ orig_down)
|
||||
|
||||
# We need to create new low-rank matrices that represent this delta
|
||||
# One approach is to do SVD on delta_w
|
||||
U, S, V = torch.linalg.svd(delta_w, full_matrices=False)
|
||||
U, S, V = torch.linalg.svd(delta_w.to(device="cuda", dtype=torch.float32), full_matrices=False)
|
||||
|
||||
# Take the top 2*r singular values (as suggested in the paper)
|
||||
rank = rank * 2
|
||||
@@ -1124,7 +1116,8 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_down_key = f"{lora.lora_name}.lora_down.weight"
|
||||
lora_up = state_dict[lora_up_key]
|
||||
lora_down = state_dict[lora_down_key]
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim)
|
||||
with torch.autocast("cuda"):
|
||||
up, down = convert_pissa_to_standard_lora(lora_up, lora_down, lora._org_lora_up, lora._org_lora_down, lora.lora_dim)
|
||||
state_dict[lora_up_key] = up.detach()
|
||||
state_dict[lora_down_key] = down.detach()
|
||||
progress.update(1)
|
||||
|
||||
216
tests/library/test_lora_util.py
Normal file
216
tests/library/test_lora_util.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import torch
|
||||
import pytest
|
||||
from library.lora_util import initialize_pissa
|
||||
from tests.test_util import generate_synthetic_weights
|
||||
|
||||
|
||||
def test_initialize_pissa_basic():
    """PiSSA init runs on a simple Linear and modifies all participating weights."""
    org_module = torch.nn.Linear(10, 5)
    # Fix: the synthetic weights were previously overwritten by an immediate
    # xavier_uniform_ re-init, defeating generate_synthetic_weights entirely.
    org_module.weight.data = generate_synthetic_weights(org_module.weight)
    torch.nn.init.zeros_(org_module.bias)

    # Create LoRA layers with shapes matching rank=2
    lora_down = torch.nn.Linear(10, 2)
    lora_up = torch.nn.Linear(2, 5)

    # Store original weight for comparison
    original_weight = org_module.weight.data.clone()

    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2)

    # Verify basic properties
    assert lora_down.weight.data is not None
    assert lora_up.weight.data is not None
    assert org_module.weight.data is not None

    # PiSSA subtracts the principal components, so the base weight must change.
    assert not torch.equal(original_weight, org_module.weight.data)
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_constraints():
    """PiSSA accepts ranks up to min(in_features, out_features)."""
    org_module = torch.nn.Linear(20, 10)
    # Fix: the synthetic weights were previously overwritten by an immediate
    # xavier_uniform_ re-init, defeating generate_synthetic_weights entirely.
    org_module.weight.data = generate_synthetic_weights(org_module.weight)
    torch.nn.init.zeros_(org_module.bias)

    # Rank strictly below the minimum dimension
    lora_down = torch.nn.Linear(20, 3)
    lora_up = torch.nn.Linear(3, 10)
    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)

    # Rank equal to the minimum dimension (full-rank edge case)
    lora_down = torch.nn.Linear(20, 10)
    lora_up = torch.nn.Linear(10, 10)
    initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=10)
|
||||
|
||||
|
||||
def test_initialize_pissa_shape_mismatch():
    """Mismatched LoRA shapes should emit a UserWarning rather than crash."""
    base = torch.nn.Linear(20, 10)

    # Shapes deliberately inconsistent with rank=3 to drive the warning path.
    bad_down = torch.nn.Linear(20, 5)
    bad_up = torch.nn.Linear(3, 15)

    with pytest.warns(UserWarning):
        initialize_pissa(base, bad_down, bad_up, scale=0.1, rank=3)
|
||||
|
||||
|
||||
def test_initialize_pissa_scaling():
    """The base-weight delta equals scale * (up @ down) for several scales."""
    for factor in (0.0, 0.1, 1.0):
        base = torch.nn.Linear(10, 5)
        base.weight.data = generate_synthetic_weights(base.weight)
        before = base.weight.data.clone()

        down = torch.nn.Linear(10, 2)
        up = torch.nn.Linear(2, 5)

        initialize_pissa(base, down, up, scale=factor, rank=2)

        # The subtraction performed by PiSSA must be exactly the scaled product.
        actual_delta = before - base.weight.data
        wanted_delta = factor * (up.weight.data @ down.weight.data)
        torch.testing.assert_close(actual_delta, wanted_delta, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_initialize_pissa_dtype():
    """The base module keeps its original dtype after initialization."""
    for dt in (torch.float16, torch.float32, torch.float64):
        base = torch.nn.Linear(10, 5).to(dtype=dt)
        base.weight.data = generate_synthetic_weights(base.weight)

        down = torch.nn.Linear(10, 2)
        up = torch.nn.Linear(2, 5)

        initialize_pissa(base, down, up, scale=0.1, rank=2)

        # The residual weight must be cast back to the base dtype.
        assert base.weight.dtype == dt
|
||||
|
||||
|
||||
def test_initialize_pissa_svd_properties():
    """Adding back scale * (up @ down) reconstructs the original weight."""
    base = torch.nn.Linear(20, 10)
    down = torch.nn.Linear(20, 3)
    up = torch.nn.Linear(3, 10)

    base.weight.data = generate_synthetic_weights(base.weight)
    before = base.weight.data.clone()

    initialize_pissa(base, down, up, scale=0.1, rank=3)

    # original - scale*(up@down) must equal the residual stored in the module.
    rebuilt = before - 0.1 * (up.weight.data @ down.weight.data)
    torch.testing.assert_close(rebuilt, base.weight.data, rtol=1e-4, atol=1e-4)
|
||||
|
||||
|
||||
def test_initialize_pissa_device_handling():
    """An explicit `device` argument keeps every tensor on that device."""
    # Fix: the original list contained CPU twice when CUDA was unavailable,
    # silently running the same scenario two times.
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))

    for device in devices:
        # Create modules on the target device
        org_module = torch.nn.Linear(10, 5).to(device)
        lora_down = torch.nn.Linear(10, 2).to(device)
        lora_up = torch.nn.Linear(2, 5).to(device)

        # Initialization with an explicit device
        initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=2, device=device)

        # All three weights must stay on the requested device
        assert org_module.weight.data.device.type == device.type
        assert lora_down.weight.data.device.type == device.type
        assert lora_up.weight.data.device.type == device.type
|
||||
|
||||
|
||||
def test_initialize_pissa_dtype_preservation():
    """All three modules keep their dtype when they share it from the start."""
    for dt in (torch.float16, torch.float32, torch.float64):
        base = torch.nn.Linear(10, 5).to(dtype=dt)
        down = torch.nn.Linear(10, 2).to(dtype=dt)
        up = torch.nn.Linear(2, 5).to(dtype=dt)

        initialize_pissa(base, down, up, scale=0.1, rank=2)

        # No module's dtype may drift through the float32 SVD round-trip.
        assert base.weight.dtype == dt
        assert down.weight.dtype == dt
        assert up.weight.dtype == dt
|
||||
|
||||
|
||||
def test_initialize_pissa_rank_limits():
    """Both rank=1 and the maximal rank min(in, out) initialize cleanly."""
    base = torch.nn.Linear(10, 5)

    # Smallest possible rank
    initialize_pissa(base, torch.nn.Linear(10, 1), torch.nn.Linear(1, 5), scale=0.1, rank=1)

    # Largest possible rank: min(in_features, out_features)
    top = min(10, 5)
    initialize_pissa(base, torch.nn.Linear(10, top), torch.nn.Linear(top, 5), scale=0.1, rank=top)
|
||||
|
||||
|
||||
def test_initialize_pissa_numerical_stability():
    """PiSSA should handle tiny, huge, and all-zero weight matrices."""
    scenarios = [
        torch.randn(20, 10) * 1e-10,  # very small values
        torch.randn(20, 10) * 1e10,  # very large values
        torch.zeros(20, 10),  # zero matrix (degenerate SVD)
    ]

    for i, weight_matrix in enumerate(scenarios):
        # Fix: Linear(10, 20) has a (20, 10) weight, matching the scenario
        # shape. The original declared Linear(20, 10) and silently replaced its
        # (10, 20) weight with transposed-shape data.
        org_module = torch.nn.Linear(10, 20)
        org_module.weight.data = weight_matrix

        lora_down = torch.nn.Linear(10, 3)
        lora_up = torch.nn.Linear(3, 20)

        try:
            initialize_pissa(org_module, lora_down, lora_up, scale=0.1, rank=3)
        except Exception as e:
            pytest.fail(f"Initialization failed for scenario ({i}): {e}")
|
||||
|
||||
|
||||
def test_initialize_pissa_scale_effects():
    """scale=0 leaves the base weight intact; any positive scale changes it."""
    base = torch.nn.Linear(10, 5)
    pristine = base.weight.data.clone()

    for factor in (0.0, 0.1, 0.5, 1.0):
        # Restore the base weight before each run
        base.weight.data = pristine.clone()

        down = torch.nn.Linear(10, 2)
        up = torch.nn.Linear(2, 5)

        initialize_pissa(base, down, up, scale=factor, rank=2)

        delta = pristine - base.weight.data

        # Zero scale must be a perfect no-op on the base weight; any other
        # scale has to move it.
        if factor == 0.0:
            torch.testing.assert_close(delta, torch.zeros_like(delta), rtol=1e-4, atol=1e-6)
        else:
            assert not torch.allclose(delta, torch.zeros_like(delta), rtol=1e-4, atol=1e-6)
|
||||
609
tests/networks/test_lora_flux.py
Normal file
609
tests/networks/test_lora_flux.py
Normal file
@@ -0,0 +1,609 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from networks.lora_flux import LoRAModule, LoRANetwork, create_network
|
||||
from tests.test_util import generate_synthetic_weights
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def test_basic_linear_module_initialization():
    """A plain Linear wraps into a LoRAModule with correctly sized projections."""
    base = nn.Linear(10, 20)
    module = LoRAModule(lora_name="test_linear", org_module=base, lora_dim=4)

    # Identity attributes
    assert module.lora_name == "test_linear"
    assert module.lora_dim == 4

    # Non-split path produces plain Linear projections
    assert isinstance(module.lora_down, nn.Linear)
    assert isinstance(module.lora_up, nn.Linear)

    # down: in_features -> rank, up: rank -> out_features
    assert module.lora_down.in_features == 10
    assert module.lora_down.out_features == 4
    assert module.lora_up.in_features == 4
    assert module.lora_up.out_features == 20
|
||||
|
||||
|
||||
def test_split_dims_initialization():
    """split_dims produces parallel ModuleLists with per-split up projections."""
    org_module = nn.Linear(10, 15)
    split_dims = [5, 5, 5]
    lora_module = LoRAModule(lora_name="test_split_dims", org_module=org_module, lora_dim=4, split_dims=split_dims)

    # split_dims specific attributes
    assert lora_module.split_dims == split_dims
    assert isinstance(lora_module.lora_down, nn.ModuleList)
    assert isinstance(lora_module.lora_up, nn.ModuleList)

    # One down/up pair per split
    assert len(lora_module.lora_down) == len(split_dims)
    assert len(lora_module.lora_up) == len(split_dims)

    # Fix: pair each split module with its expected output width — the original
    # `up.out_features in [5, 5, 5]` check could not catch a mis-sized split.
    for down, up, dim in zip(lora_module.lora_down, lora_module.lora_up, split_dims):
        assert down.in_features == 10
        assert down.out_features == 4
        assert up.in_features == 4
        assert up.out_features == dim
|
||||
|
||||
|
||||
def test_alpha_scaling():
    """scale == alpha / lora_dim, with alpha falling back to lora_dim when 0."""
    base = nn.Linear(10, 20)

    # alpha=0 falls back to lora_dim, giving a neutral scale of 1.0
    defaulted = LoRAModule(lora_name="test_alpha1", org_module=base, lora_dim=4, alpha=0)
    assert defaulted.scale == 1.0

    # Explicit alpha: 2 / 4 = 0.5
    custom = LoRAModule(lora_name="test_alpha2", org_module=base, lora_dim=4, alpha=2)
    assert custom.scale == 0.5
|
||||
|
||||
|
||||
def test_initialization_methods():
    """default / urae / pissa init paths all produce correctly shaped factors."""
    org_module = nn.Linear(10, 20)
    org_module.weight.data = generate_synthetic_weights(org_module.weight)

    # Default initialization
    lora_module1 = LoRAModule(lora_name="test_init_default", org_module=org_module, lora_dim=4)
    assert lora_module1.lora_down.weight.shape == (4, 10)
    assert lora_module1.lora_up.weight.shape == (20, 4)

    # URAE initialization stores the initial factors for later conversion.
    # Fix: the first assertion now checks _org_lora_up itself — the original
    # asserted _org_lora_down in both lines (copy-paste bug).
    lora_module2 = LoRAModule(lora_name="test_init_urae", org_module=org_module, lora_dim=4, initialize="urae")
    assert hasattr(lora_module2, "_org_lora_up") and lora_module2._org_lora_up is not None
    assert hasattr(lora_module2, "_org_lora_down") and lora_module2._org_lora_down is not None
    assert lora_module2.lora_down.weight.shape == (4, 10)
    assert lora_module2.lora_up.weight.shape == (20, 4)

    # PISSA initialization likewise stores the initial factors.
    lora_module3 = LoRAModule(lora_name="test_init_pissa", org_module=org_module, lora_dim=4, initialize="pissa")
    assert hasattr(lora_module3, "_org_lora_up") and lora_module3._org_lora_up is not None
    assert hasattr(lora_module3, "_org_lora_down") and lora_module3._org_lora_down is not None
    assert lora_module3.lora_down.weight.shape == (4, 10)
    assert lora_module3.lora_up.weight.shape == (20, 4)
|
||||
|
||||
|
||||
@torch.no_grad()
def test_forward_basic_linear():
    """Forward through an applied LoRAModule returns a well-formed tensor."""
    base = nn.Linear(10, 20)
    base.weight.data = torch.testing.make_tensor(
        base.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
    )

    module = LoRAModule(lora_name="test_forward", org_module=base, lora_dim=4, alpha=4, multiplier=1.0)
    module.apply_to()

    assert isinstance(module.lora_down, nn.Linear)
    assert isinstance(module.lora_up, nn.Linear)

    # Deterministic, strictly positive LoRA factors for a reproducible pass.
    for layer in (module.lora_down, module.lora_up):
        layer.weight.data = torch.testing.make_tensor(
            layer.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
        )

    x = torch.ones(5, 10)
    output = module.forward(x)

    # Structural checks
    assert output is not None, "Output should not be None"
    assert isinstance(output, torch.Tensor), "Output should be a torch.Tensor"

    # Shape / dtype / device checks
    assert output.shape == (5, 20), "Output shape should match expected dimensions"
    assert output.dtype == torch.float32, "Output should be float32"
    assert output.device == x.device, "Output should be on the same device as input"
|
||||
|
||||
|
||||
def test_forward_module_dropout():
    """With module_dropout=1.0 in train mode, the LoRA path is always skipped."""
    base = nn.Linear(10, 20)

    module = LoRAModule(
        lora_name="test_module_dropout",
        org_module=base,
        lora_dim=4,
        multiplier=1.0,
        module_dropout=1.0,  # Always drop
    )
    module.apply_to()

    x = torch.ones(5, 10)
    module.train()

    # With the whole module dropped, the output must reduce to the base layer.
    output = module.forward(x)
    reference = base(x)
    torch.testing.assert_close(output, reference)
|
||||
|
||||
|
||||
def test_forward_rank_dropout():
    """rank_dropout=0.5 randomly zeroes ranks of the low-rank activation in
    training mode, so repeated forward passes should not all be identical."""
    base = nn.Linear(10, 20)

    lora = LoRAModule(
        lora_name="test_rank_dropout",
        org_module=base,
        lora_dim=4,
        multiplier=1.0,
        rank_dropout=0.5,  # half the ranks dropped per pass
    )
    lora.apply_to()

    assert isinstance(lora.lora_down, nn.Linear)
    assert isinstance(lora.lora_up, nn.Linear)

    # Deterministic, strictly positive weights so the LoRA branch is non-zero.
    lora.lora_down.weight.data = torch.testing.make_tensor(
        lora.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
    )
    lora.lora_up.weight.data = torch.testing.make_tensor(
        lora.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
    )

    inputs = torch.ones(5, 10)

    # rank_dropout is only active while training.
    lora.train()

    outputs = [lora.forward(inputs) for _ in range(10)]

    # pairwise_equal[k] is True when two passes produced identical outputs;
    # dropout must make at least one pair differ.
    pairwise_equal = [
        torch.all(torch.eq(outputs[i], outputs[j])).item()
        for i in range(len(outputs))
        for j in range(i + 1, len(outputs))
    ]
    assert not all(pairwise_equal)
|
||||
|
||||
|
||||
def test_forward_split_dims():
    """split_dims turns lora_down/lora_up into ModuleLists; the forward output
    must still span the full (summed) output width."""
    base = nn.Linear(10, 15)

    lora = LoRAModule(lora_name="test_split_dims", org_module=base, lora_dim=4, multiplier=1.0, split_dims=[5, 5, 5])
    lora.apply_to()

    assert isinstance(lora.lora_down, nn.ModuleList)
    assert isinstance(lora.lora_up, nn.ModuleList)

    # Deterministic weights for every split branch.
    for down in lora.lora_down:
        assert isinstance(down, nn.Linear)
        down.weight.data = torch.testing.make_tensor(
            down.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
        )
    for up in lora.lora_up:
        assert isinstance(up, nn.Linear)
        up.weight.data = torch.testing.make_tensor(
            up.weight.data.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
        )

    inputs = torch.ones(5, 10)
    result = lora.forward(inputs)

    # 5 + 5 + 5 split outputs reassemble into the original 15-wide output.
    assert result.shape == (5, 15)
|
||||
|
||||
|
||||
def test_forward_dropout():
    """Plain neuron dropout (dropout=0.5) in training mode must make repeated
    forward passes disagree at least once."""
    base = nn.Linear(10, 20)

    lora = LoRAModule(
        lora_name="test_dropout",
        org_module=base,
        lora_dim=4,
        multiplier=1.0,
        dropout=0.5,  # half the activations dropped per pass
    )
    lora.apply_to()

    assert isinstance(lora.lora_down, nn.Linear)
    assert isinstance(lora.lora_up, nn.Linear)

    # Deterministic, strictly positive weights so the LoRA branch is non-zero.
    lora.lora_down.weight.data = torch.testing.make_tensor(
        lora.lora_down.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
    )
    lora.lora_up.weight.data = torch.testing.make_tensor(
        lora.lora_up.weight.shape, dtype=torch.float32, device="cpu", low=0.1, high=1.0
    )

    inputs = torch.ones(5, 10)

    # Dropout is only active while training.
    lora.train()

    outputs = [lora.forward(inputs) for _ in range(10)]

    # pairwise_equal[k] is True when two passes produced identical outputs;
    # dropout must make at least one pair differ.
    pairwise_equal = [
        torch.all(torch.eq(outputs[i], outputs[j])).item()
        for i in range(len(outputs))
        for j in range(i + 1, len(outputs))
    ]
    assert not all(pairwise_equal)
|
||||
|
||||
|
||||
def test_create_network_default_parameters(mock_text_encoder, mock_flux):
    """create_network with dim/alpha left as None must fall back to the
    defaults (lora_dim=4, alpha=1.0)."""
    mock_ae = MagicMock()
    encoders = [mock_text_encoder, mock_text_encoder]

    # Minimal invocation: everything optional left at its default.
    network = create_network(
        multiplier=1.0,
        network_dim=None,
        network_alpha=None,
        ae=mock_ae,
        text_encoders=encoders,
        flux=mock_flux,
    )

    assert network is not None
    assert network.multiplier == 1.0
    assert network.lora_dim == 4  # default network_dim
    assert network.alpha == 1.0  # default network_alpha
|
||||
|
||||
|
||||
@pytest.fixture
def mock_text_encoder():
    """Minimal stand-in for a CLIP text encoder: three attention-like modules,
    each holding three Linear layers, so LoRA target discovery has modules to
    walk. The class name CLIPAttention is kept — module matching presumably
    keys on it (TODO confirm against network target-module names)."""

    class CLIPAttention(nn.Module):
        def __init__(self):
            super().__init__()
            # Dummy projections standing in for a real CLIPAttention.
            self.layers = nn.ModuleList([nn.Linear(10, 15) for _ in range(3)])

    class MockTextEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            # Dummy attention stack standing in for a CLIPTextModel.
            self.attns = nn.ModuleList([CLIPAttention() for _ in range(3)])

    return MockTextEncoder()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_flux():
    """Minimal stand-in for a FLUX model: 3 double-stream and 3 single-stream
    blocks, each with three Linear layers. Block class names are kept — LoRA
    target discovery presumably keys on them (TODO confirm)."""

    class DoubleStreamBlock(nn.Module):
        def __init__(self):
            super().__init__()
            # Dummy layers standing in for a real DoubleStreamBlock.
            self.layers = nn.ModuleList([nn.Linear(10, 15) for _ in range(3)])

    class SingleStreamBlock(nn.Module):
        def __init__(self):
            super().__init__()
            # Dummy layers standing in for a real SingleStreamBlock.
            self.layers = nn.ModuleList([nn.Linear(10, 15) for _ in range(3)])

    class MockFlux(nn.Module):
        def __init__(self):
            super().__init__()
            # Dummy block stacks standing in for a Flux model.
            self.double_blocks = nn.ModuleList([DoubleStreamBlock() for _ in range(3)])
            self.single_blocks = nn.ModuleList([SingleStreamBlock() for _ in range(3)])

    return MockFlux()
|
||||
|
||||
|
||||
def test_create_network_custom_parameters(mock_text_encoder, mock_flux):
    """Every supported kwarg of create_network must be parsed and reflected on
    the resulting network (string booleans coerced, conv dims applied, ...)."""
    mock_ae = MagicMock()
    encoders = [mock_text_encoder, mock_text_encoder]

    # Exercise the full kwargs surface; string values mimic CLI-style input.
    custom_params = {
        "conv_dim": 8,
        "conv_alpha": 0.5,
        "img_attn_dim": 16,
        "txt_attn_dim": 16,
        "neuron_dropout": 0.1,
        "rank_dropout": 0.2,
        "module_dropout": 0.3,
        "train_blocks": "double",
        "split_qkv": "True",
        "train_t5xxl": "True",
        "in_dims": "[64, 32, 16, 8, 4]",
        "verbose": "True",
    }

    network = create_network(
        multiplier=1.5,
        network_dim=8,
        network_alpha=2.0,
        ae=mock_ae,
        text_encoders=encoders,
        flux=mock_flux,
        **custom_params,
    )

    assert network is not None
    assert network.multiplier == 1.5
    assert network.lora_dim == 8
    assert network.alpha == 2.0
    assert network.conv_lora_dim == 8
    assert network.conv_alpha == 0.5
    assert network.train_blocks == "double"
    # String "True" inputs must have been coerced to real booleans.
    assert network.split_qkv is True
    assert network.train_t5xxl is True
|
||||
|
||||
|
||||
def test_create_network_block_indices(mock_text_encoder, mock_flux):
    """Index-spec strings ("0-2,4" and "1,3") must parse into per-block
    boolean masks over the 19 double and 38 single FLUX blocks."""
    mock_ae = MagicMock()
    encoders = [mock_text_encoder, mock_text_encoder]

    network = create_network(
        multiplier=1.0,
        network_dim=4,
        network_alpha=1.0,
        ae=mock_ae,
        text_encoders=encoders,
        flux=mock_flux,
        neuron_dropout=None,
        **{"train_double_block_indices": "0-2,4", "train_single_block_indices": "1,3"},
    )

    assert network.train_double_block_indices is not None
    assert network.train_single_block_indices is not None

    # "0-2,4" -> blocks 0, 1, 2 and 4 enabled out of 19 double blocks.
    expected_double = [i in (0, 1, 2, 4) for i in range(19)]
    # "1,3" -> blocks 1 and 3 enabled out of 38 single blocks.
    expected_single = [i in (1, 3) for i in range(38)]

    assert network.train_double_block_indices == expected_double
    assert network.train_single_block_indices == expected_single
|
||||
|
||||
|
||||
def test_create_network_loraplus_ratios(mock_text_encoder, mock_flux):
    """All three LoRA+ learning-rate ratios (global, unet, text-encoder) must
    be stored on the network when supplied."""
    mock_ae = MagicMock()
    encoders = [mock_text_encoder, mock_text_encoder]

    network = create_network(
        multiplier=1.0,
        network_dim=4,
        network_alpha=1.0,
        ae=mock_ae,
        text_encoders=encoders,
        flux=mock_flux,
        neuron_dropout=None,
        **{"loraplus_lr_ratio": 2.0, "loraplus_unet_lr_ratio": 1.5, "loraplus_text_encoder_lr_ratio": 1.0},
    )

    assert network.loraplus_lr_ratio == 2.0
    assert network.loraplus_unet_lr_ratio == 1.5
    assert network.loraplus_text_encoder_lr_ratio == 1.0
|
||||
|
||||
|
||||
def test_create_network_loraplus_default_ratio(mock_text_encoder, mock_flux):
    """When only the global LoRA+ ratio is given, the per-component
    (unet / text-encoder) ratios must remain unset (None)."""
    mock_ae = MagicMock()
    mock_text_encoders = [mock_text_encoder, mock_text_encoder]

    network = create_network(
        multiplier=1.0,
        network_dim=4,
        network_alpha=1.0,
        ae=mock_ae,
        text_encoders=mock_text_encoders,
        flux=mock_flux,
        # Fixed: was misspelled "nueral_dropout", which bypassed the real
        # parameter and leaked into **kwargs. Siblings use neuron_dropout.
        neuron_dropout=None,
        **{"loraplus_lr_ratio": 2.0},
    )

    # Only the global ratio is set; component-specific ratios stay None.
    assert network.loraplus_lr_ratio == 2.0
    assert network.loraplus_unet_lr_ratio is None
    assert network.loraplus_text_encoder_lr_ratio is None
|
||||
|
||||
|
||||
def test_create_network_invalid_inputs(mock_text_encoder, mock_flux):
    """Invalid kwargs must be rejected: an unknown train_blocks value and an
    in_dims list of the wrong length each raise AssertionError."""
    # Removed the no-op self-assignment `mock_flux = mock_flux`.
    mock_ae = MagicMock()
    mock_text_encoders = [mock_text_encoder, mock_text_encoder]

    # Unknown train_blocks value is rejected.
    with pytest.raises(AssertionError):
        create_network(
            multiplier=1.0,
            network_dim=4,
            network_alpha=1.0,
            ae=mock_ae,
            text_encoders=mock_text_encoders,
            flux=mock_flux,
            neuron_dropout=None,
            **{"train_blocks": "invalid"},
        )

    # in_dims must have exactly 5 entries.
    with pytest.raises(AssertionError):
        create_network(
            multiplier=1.0,
            network_dim=4,
            network_alpha=1.0,
            ae=mock_ae,
            text_encoders=mock_text_encoders,
            flux=mock_flux,
            neuron_dropout=None,
            **{"in_dims": "[1,2,3]"},  # Should be 5 dimensions
        )
|
||||
|
||||
|
||||
def test_lora_network_initialization(mock_text_encoder, mock_flux):
    """Constructing LoRANetwork with defaults must set the documented default
    attributes and create LoRA modules for both encoders and the unet."""
    network = LoRANetwork(text_encoders=[mock_text_encoder, mock_text_encoder], unet=mock_flux)

    # Default hyperparameters.
    assert network.multiplier == 1.0
    assert network.lora_dim == 4
    assert network.alpha == 1
    assert network.train_blocks == "all"

    # LoRA modules were created for both halves of the model.
    assert len(network.text_encoder_loras) > 0
    assert len(network.unet_loras) > 0
|
||||
|
||||
|
||||
def test_lora_network_initialization_with_custom_params(mock_text_encoder, mock_flux):
    """Every constructor argument of LoRANetwork must be stored verbatim."""
    params = dict(
        multiplier=0.5,
        lora_dim=8,
        alpha=2.0,
        dropout=0.1,
        rank_dropout=0.05,
        module_dropout=0.02,
        train_blocks="single",
        split_qkv=True,
    )
    network = LoRANetwork(text_encoders=[mock_text_encoder], unet=mock_flux, **params)

    # Each custom value round-trips onto the network unchanged.
    assert network.multiplier == 0.5
    assert network.lora_dim == 8
    assert network.alpha == 2.0
    assert network.dropout == 0.1
    assert network.rank_dropout == 0.05
    assert network.module_dropout == 0.02
    assert network.train_blocks == "single"
    assert network.split_qkv is True
|
||||
|
||||
|
||||
def test_lora_network_initialization_with_custom_modules_dim(mock_text_encoder, mock_flux):
    """Per-module dim/alpha overrides (modules_dim / modules_alpha, keyed by
    LoRA module name) must shape each module's down/up projections and alpha."""
    modules_dim = {"lora_te1_attns_0_layers_0": 16, "lora_unet_double_blocks_0_layers_0": 8}
    modules_alpha = {"lora_te1_attns_0_layers_0": 2, "lora_unet_double_blocks_0_layers_0": 1}

    network = LoRANetwork(
        text_encoders=[mock_text_encoder, mock_text_encoder], unet=mock_flux, modules_dim=modules_dim, modules_alpha=modules_alpha
    )

    # Expected module shapes (base layers are Linear(10, 15)):
    #   unet: lora_down Linear(10, 8),  lora_up Linear(8, 15)
    #   te1:  lora_down Linear(10, 16), lora_up Linear(16, 15)
    unet_lora = network.unet_loras[0]
    unet_key = "lora_unet_double_blocks_0_layers_0"
    assert isinstance(unet_lora.lora_down, torch.nn.Linear)
    assert isinstance(unet_lora.lora_up, torch.nn.Linear)
    # nn.Linear weight is (out_features, in_features): rank sits on the down
    # module's rows and the up module's columns.
    assert unet_lora.lora_down.weight.data.shape[0] == modules_dim[unet_key]
    assert unet_lora.lora_up.weight.data.shape[1] == modules_dim[unet_key]
    assert unet_lora.alpha == modules_alpha[unet_key]

    te_lora = network.text_encoder_loras[0]
    te_key = "lora_te1_attns_0_layers_0"
    assert isinstance(te_lora.lora_down, torch.nn.Linear)
    assert isinstance(te_lora.lora_up, torch.nn.Linear)
    assert te_lora.lora_down.weight.data.shape[0] == modules_dim[te_key]
    assert te_lora.lora_up.weight.data.shape[1] == modules_dim[te_key]
    assert te_lora.alpha == modules_alpha[te_key]
|
||||
Reference in New Issue
Block a user