Add comprehensive tests for URAE initialization

- Introduced test_urae_initialization_with_level_variations()
- Covered different precision levels
- Tested various rank and configuration scenarios
- Expanded test coverage for URAE initialization
This commit is contained in:
rockerBOO
2025-06-16 12:55:44 -04:00
parent d635838af2
commit 925a4a4df0

View File

@@ -1,7 +1,8 @@
import pytest
import torch
from typing import Optional, Dict, Any
from library.network_utils import convert_urae_to_standard_lora
from library.network_utils import convert_urae_to_standard_lora, initialize_urae
class TestURAE:
@@ -201,3 +202,86 @@ class TestURAE:
max_possible_rank = min(dim_in, dim_out)
assert lora_up.shape[1] <= max_possible_rank
assert lora_down.shape[0] <= max_possible_rank
def create_mock_urae_environment(level: Optional[str] = None) -> Dict[str, Any]:
"""Create a mock environment simulating different URAE initialization levels."""
base_config = {
"precision": torch.float32,
"rank_multiplier": 1.0,
"scale_factor": 0.1,
"use_minor_components": False
}
urae_levels = {
"low": {
"precision": torch.float16,
"rank_multiplier": 0.5,
"scale_factor": 0.05,
"use_minor_components": True
},
"medium": {
"precision": torch.float32,
"rank_multiplier": 1.0,
"scale_factor": 0.1,
"use_minor_components": False
},
"high": {
"precision": torch.float64,
"rank_multiplier": 2.0,
"scale_factor": 0.2,
"use_minor_components": False
}
}
if level and level in urae_levels:
base_config.update(urae_levels[level])
return base_config
def test_urae_initialization_with_level_variations():
"""Test URAE initialization with different levels and configurations."""
import torch
# Test URAE levels
urae_levels = [None, "low", "medium", "high"]
for level in urae_levels:
# Get URAE-specific configuration
urae_config = create_mock_urae_environment(level)
# Adjust input sizes based on rank multiplier
input_dim = int(100 * urae_config["rank_multiplier"])
output_dim = int(50 * urae_config["rank_multiplier"])
rank = int(10 * urae_config["rank_multiplier"])
# Create modules
org_module = torch.nn.Linear(input_dim, output_dim)
lora_down = torch.nn.Linear(input_dim, rank)
lora_up = torch.nn.Linear(rank, output_dim)
# Apply precision
org_module = org_module.to(dtype=urae_config["precision"])
lora_down = lora_down.to(dtype=urae_config["precision"])
lora_up = lora_up.to(dtype=urae_config["precision"])
# Apply device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
org_module = org_module.to(device)
lora_down = lora_down.to(device)
lora_up = lora_up.to(device)
# Test initialization
try:
initialize_urae(
org_module,
lora_down,
lora_up,
scale=urae_config["scale_factor"],
rank=rank,
)
except Exception as e:
pytest.fail(f"URAE initialization failed for level {level}: {e}")