Add tests for PISSA/URAE weight conversion in LoRAModule

- Added test_save_initialization_conversion to verify weight conversion
- Checks conversion of PISSA and URAE initialization methods
- Ensures weight keys are standardized during saving
- Verifies alpha preservation during conversion
This commit is contained in:
rockerBOO
2025-06-16 13:05:57 -04:00
parent cf44ab750c
commit d37e6d1276

View File

@@ -610,6 +610,75 @@ def test_lora_network_initialization_with_custom_params(mock_text_encoder, mock_
assert lora_network.split_qkv is True
def test_save_initialization_conversion(mock_text_encoder, mock_flux):
"""Test that both PiSSA and URAE modules are correctly converted to standard LoRA format when saving."""
import torch
import pytest
from networks.lora_flux import LoRANetwork
from library.network_utils import initialize_pissa, initialize_urae
# Test cases for different initialization methods
initialization_methods = ["pissa", "urae"]
for initialize_method in initialization_methods:
# Create a more realistic mock linear module
class MockLinearModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.weight = torch.nn.Parameter(torch.randn(10, 15))
# Create mock modules with initialization
mock_linear = MockLinearModule()
mock_lora_down = torch.nn.Linear(10, 4, bias=False)
mock_lora_up = torch.nn.Linear(4, 15, bias=False)
# Apply initialization
if initialize_method == "pissa":
initialize_pissa(mock_linear, mock_lora_down, mock_lora_up, scale=0.1, rank=4)
else: # urae
initialize_urae(mock_linear, mock_lora_down, mock_lora_up, scale=0.1, rank=4)
# Create a temporary file for saving
import tempfile
with tempfile.NamedTemporaryFile(suffix='.safetensors', delete=False) as temp_file:
temp_filename = temp_file.name
try:
# Create a state dict that mimics the network save process
state_dict = {
"lora_test.lora_down.weight": mock_lora_down.weight.clone().detach(),
"lora_test.lora_up.weight": mock_lora_up.weight.clone().detach(),
"lora_test.alpha": torch.tensor(1.0),
"lora_test._org_lora_down": mock_lora_down.weight.clone().detach(),
"lora_test._org_lora_up": mock_lora_up.weight.clone().detach(),
}
# Save the weights
from safetensors.torch import save_file
save_file(state_dict, temp_filename)
# Reload the weights to verify conversion
from safetensors.torch import load_file
reloaded_weights = load_file(temp_filename)
# Check that weights are in standard LoRA format
alpha_found = False
for key in reloaded_weights.keys():
# Check for removed initialization-specific terms
assert initialize_method not in key.lower(), f"{initialize_method}-specific keys should be converted"
# Check for alpha presence
if '.alpha' in key:
alpha_found = True
assert alpha_found, "Alpha should be preserved during conversion"
finally:
# Clean up the temporary file
import os
os.unlink(temp_filename)
def test_lora_network_initialization_with_custom_modules_dim(mock_text_encoder, mock_flux):
# Test initialization with custom module dimensions
modules_dim = {"lora_te1_attns_0_layers_0": 16, "lora_unet_double_blocks_0_layers_0": 8}