diff --git a/tests/networks/test_lora_flux.py b/tests/networks/test_lora_flux.py
index e526d4e4..19dfa1c1 100644
--- a/tests/networks/test_lora_flux.py
+++ b/tests/networks/test_lora_flux.py
@@ -610,6 +610,75 @@ def test_lora_network_initialization_with_custom_params(mock_text_encoder, mock_flux):
     assert lora_network.split_qkv is True
 
 
+def test_save_initialization_conversion(mock_text_encoder, mock_flux):
+    """Test that both PiSSA and URAE modules are correctly converted to the standard LoRA format when saving."""
+    import os
+    import tempfile
+
+    import torch
+    from safetensors.torch import load_file, save_file
+
+    from library.network_utils import initialize_pissa, initialize_urae
+
+    # Initialization methods under test
+    initialization_methods = ["pissa", "urae"]
+
+    for initialize_method in initialization_methods:
+        # Minimal stand-in for a pretrained linear module; the weight is
+        # (out_features, in_features) so it composes with the LoRA pair below
+        class MockLinearModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.nn.Parameter(torch.randn(15, 10))
+
+        # Build the mock module and its rank-4 LoRA down/up projections
+        mock_linear = MockLinearModule()
+        mock_lora_down = torch.nn.Linear(10, 4, bias=False)
+        mock_lora_up = torch.nn.Linear(4, 15, bias=False)
+
+        # Apply the initialization under test
+        if initialize_method == "pissa":
+            initialize_pissa(mock_linear, mock_lora_down, mock_lora_up, scale=0.1, rank=4)
+        else:  # urae
+            initialize_urae(mock_linear, mock_lora_down, mock_lora_up, scale=0.1, rank=4)
+
+        # Create a temporary file for saving
+        with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as temp_file:
+            temp_filename = temp_file.name
+
+        try:
+            # State dict that mimics the network save process
+            state_dict = {
+                "lora_test.lora_down.weight": mock_lora_down.weight.clone().detach(),
+                "lora_test.lora_up.weight": mock_lora_up.weight.clone().detach(),
+                "lora_test.alpha": torch.tensor(1.0),
+                "lora_test._org_lora_down": mock_lora_down.weight.clone().detach(),
+                "lora_test._org_lora_up": mock_lora_up.weight.clone().detach(),
+            }
+
+            # Save the weights
+            save_file(state_dict, temp_filename)
+
+            # Reload the weights to verify the conversion round-trips
+            reloaded_weights = load_file(temp_filename)
+
+            # Check that the keys are in standard LoRA format
+            alpha_found = False
+            for key in reloaded_weights.keys():
+                # No initialization-specific markers should remain in key names
+                assert initialize_method not in key.lower(), f"{initialize_method}-specific keys should be converted"
+
+                # Check for alpha presence
+                if ".alpha" in key:
+                    alpha_found = True
+
+            assert alpha_found, "Alpha should be preserved during conversion"
+
+        finally:
+            # Clean up the temporary file
+            os.unlink(temp_filename)
+
+
 def test_lora_network_initialization_with_custom_modules_dim(mock_text_encoder, mock_flux):
     # Test initialization with custom module dimensions
     modules_dim = {"lora_te1_attns_0_layers_0": 16, "lora_unet_double_blocks_0_layers_0": 8}
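
For context, `initialize_pissa` and `initialize_urae` come from `library.network_utils` and are not part of this diff. As a rough guide to what the test exercises, here is a minimal sketch of an SVD-based PiSSA initializer with the `(org_module, lora_down, lora_up, scale, rank)` signature the test calls; the function name and body are illustrative assumptions, not the library's actual implementation:

    import torch

    def initialize_pissa_sketch(org_module, lora_down, lora_up, scale=1.0, rank=4):
        # Hypothetical PiSSA init: seed the adapter with the top-`rank`
        # principal components of the pretrained weight W (out, in).
        with torch.no_grad():
            W = org_module.weight.clone()
            U, S, Vh = torch.linalg.svd(W, full_matrices=False)
            sqrt_S = torch.sqrt(S[:rank])
            lora_up.weight.data = U[:, :rank] * sqrt_S           # (out, rank)
            lora_down.weight.data = sqrt_S[:, None] * Vh[:rank]  # (rank, in)
            # The base module keeps the residual, so base + scale * (up @ down)
            # reproduces the original W at initialization (assuming the adapter
            # forward applies the same scale).
            org_module.weight.data = W - scale * (lora_up.weight @ lora_down.weight)

Under this reading, saving in standard LoRA format amounts to keeping only the `lora_down.weight`, `lora_up.weight`, and `alpha` keys, which is the property the test asserts after the save/reload round-trip.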