mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Add comprehensive tests for URAE initialization
- Introduced test_urae_initialization_with_level_variations() - Covered different precision levels - Tested various rank and configuration scenarios - Expanded test coverage for URAE initialization
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from library.network_utils import convert_urae_to_standard_lora
|
||||
from library.network_utils import convert_urae_to_standard_lora, initialize_urae
|
||||
|
||||
|
||||
class TestURAE:
|
||||
@@ -201,3 +202,86 @@ class TestURAE:
|
||||
max_possible_rank = min(dim_in, dim_out)
|
||||
assert lora_up.shape[1] <= max_possible_rank
|
||||
assert lora_down.shape[0] <= max_possible_rank
|
||||
|
||||
|
||||
def create_mock_urae_environment(level: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Create a mock environment simulating different URAE initialization levels."""
|
||||
base_config = {
|
||||
"precision": torch.float32,
|
||||
"rank_multiplier": 1.0,
|
||||
"scale_factor": 0.1,
|
||||
"use_minor_components": False
|
||||
}
|
||||
|
||||
urae_levels = {
|
||||
"low": {
|
||||
"precision": torch.float16,
|
||||
"rank_multiplier": 0.5,
|
||||
"scale_factor": 0.05,
|
||||
"use_minor_components": True
|
||||
},
|
||||
"medium": {
|
||||
"precision": torch.float32,
|
||||
"rank_multiplier": 1.0,
|
||||
"scale_factor": 0.1,
|
||||
"use_minor_components": False
|
||||
},
|
||||
"high": {
|
||||
"precision": torch.float64,
|
||||
"rank_multiplier": 2.0,
|
||||
"scale_factor": 0.2,
|
||||
"use_minor_components": False
|
||||
}
|
||||
}
|
||||
|
||||
if level and level in urae_levels:
|
||||
base_config.update(urae_levels[level])
|
||||
|
||||
return base_config
|
||||
|
||||
|
||||
def test_urae_initialization_with_level_variations():
|
||||
"""Test URAE initialization with different levels and configurations."""
|
||||
import torch
|
||||
|
||||
# Test URAE levels
|
||||
urae_levels = [None, "low", "medium", "high"]
|
||||
|
||||
for level in urae_levels:
|
||||
# Get URAE-specific configuration
|
||||
urae_config = create_mock_urae_environment(level)
|
||||
|
||||
# Adjust input sizes based on rank multiplier
|
||||
input_dim = int(100 * urae_config["rank_multiplier"])
|
||||
output_dim = int(50 * urae_config["rank_multiplier"])
|
||||
rank = int(10 * urae_config["rank_multiplier"])
|
||||
|
||||
# Create modules
|
||||
org_module = torch.nn.Linear(input_dim, output_dim)
|
||||
lora_down = torch.nn.Linear(input_dim, rank)
|
||||
lora_up = torch.nn.Linear(rank, output_dim)
|
||||
|
||||
# Apply precision
|
||||
org_module = org_module.to(dtype=urae_config["precision"])
|
||||
lora_down = lora_down.to(dtype=urae_config["precision"])
|
||||
lora_up = lora_up.to(dtype=urae_config["precision"])
|
||||
|
||||
# Apply device
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
org_module = org_module.to(device)
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
# Test initialization
|
||||
try:
|
||||
initialize_urae(
|
||||
org_module,
|
||||
lora_down,
|
||||
lora_up,
|
||||
scale=urae_config["scale_factor"],
|
||||
rank=rank,
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"URAE initialization failed for level {level}: {e}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user