mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Add tests for PISSA/URAE weight conversion in LoRAModule
- Added test_save_initialization_conversion to verify weight conversion - Checks conversion of PISSA and URAE initialization methods - Ensures weight keys are standardized during saving - Verifies alpha preservation during conversion
This commit is contained in:
@@ -610,6 +610,75 @@ def test_lora_network_initialization_with_custom_params(mock_text_encoder, mock_
|
||||
assert lora_network.split_qkv is True
|
||||
|
||||
|
||||
def test_save_initialization_conversion(mock_text_encoder, mock_flux):
|
||||
"""Test that both PiSSA and URAE modules are correctly converted to standard LoRA format when saving."""
|
||||
import torch
|
||||
import pytest
|
||||
from networks.lora_flux import LoRANetwork
|
||||
from library.network_utils import initialize_pissa, initialize_urae
|
||||
|
||||
# Test cases for different initialization methods
|
||||
initialization_methods = ["pissa", "urae"]
|
||||
|
||||
for initialize_method in initialization_methods:
|
||||
# Create a more realistic mock linear module
|
||||
class MockLinearModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.randn(10, 15))
|
||||
|
||||
# Create mock modules with initialization
|
||||
mock_linear = MockLinearModule()
|
||||
mock_lora_down = torch.nn.Linear(10, 4, bias=False)
|
||||
mock_lora_up = torch.nn.Linear(4, 15, bias=False)
|
||||
|
||||
# Apply initialization
|
||||
if initialize_method == "pissa":
|
||||
initialize_pissa(mock_linear, mock_lora_down, mock_lora_up, scale=0.1, rank=4)
|
||||
else: # urae
|
||||
initialize_urae(mock_linear, mock_lora_down, mock_lora_up, scale=0.1, rank=4)
|
||||
|
||||
# Create a temporary file for saving
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(suffix='.safetensors', delete=False) as temp_file:
|
||||
temp_filename = temp_file.name
|
||||
|
||||
try:
|
||||
# Create a state dict that mimics the network save process
|
||||
state_dict = {
|
||||
"lora_test.lora_down.weight": mock_lora_down.weight.clone().detach(),
|
||||
"lora_test.lora_up.weight": mock_lora_up.weight.clone().detach(),
|
||||
"lora_test.alpha": torch.tensor(1.0),
|
||||
"lora_test._org_lora_down": mock_lora_down.weight.clone().detach(),
|
||||
"lora_test._org_lora_up": mock_lora_up.weight.clone().detach(),
|
||||
}
|
||||
|
||||
# Save the weights
|
||||
from safetensors.torch import save_file
|
||||
save_file(state_dict, temp_filename)
|
||||
|
||||
# Reload the weights to verify conversion
|
||||
from safetensors.torch import load_file
|
||||
reloaded_weights = load_file(temp_filename)
|
||||
|
||||
# Check that weights are in standard LoRA format
|
||||
alpha_found = False
|
||||
for key in reloaded_weights.keys():
|
||||
# Check for removed initialization-specific terms
|
||||
assert initialize_method not in key.lower(), f"{initialize_method}-specific keys should be converted"
|
||||
|
||||
# Check for alpha presence
|
||||
if '.alpha' in key:
|
||||
alpha_found = True
|
||||
|
||||
assert alpha_found, "Alpha should be preserved during conversion"
|
||||
|
||||
finally:
|
||||
# Clean up the temporary file
|
||||
import os
|
||||
os.unlink(temp_filename)
|
||||
|
||||
|
||||
def test_lora_network_initialization_with_custom_modules_dim(mock_text_encoder, mock_flux):
|
||||
# Test initialization with custom module dimensions
|
||||
modules_dim = {"lora_te1_attns_0_layers_0": 16, "lora_unet_double_blocks_0_layers_0": 8}
|
||||
|
||||
Reference in New Issue
Block a user