# Mirror of https://github.com/kohya-ss/sd-scripts.git
# (synced 2026-04-16 17:02:45 +00:00; 304 lines, 9.8 KiB, Python)
import math

import pytest

from library.lora_util import parse_blocks
|
|
|
|
|
|
def test_single_value():
    """A lone scalar spec is broadcast to every one of the 19 default slots."""
    cases = (
        # (spec string, value every slot must hold, failure message)
        ("1.0", 1.0, "set all values to 1.0 when default value is 1.0"),
        ("0", 0.0, "set all values to 0.0 when default value is 0"),
        ("-0.5", -0.5, "set all values to -0.5 when default value is -0.5"),
    )
    for spec, expected, message in cases:
        blocks = parse_blocks(spec)
        assert len(blocks) == 19
        assert all(v == expected for v in blocks), message
|
|
|
|
|
|
def test_explicit_list():
    """Bracketed lists are used as-is, repeated when short, and truncated when long."""
    # A 19-entry list maps one-to-one onto the blocks.
    full = parse_blocks("[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1]")
    assert len(full) == 19
    for index, expected in ((0, 0.1), (9, 1.0), (18, 0.1)):
        assert full[index] == expected

    # A 3-entry list repeats cyclically: index 3 wraps back to the first entry.
    tiled = parse_blocks("[0.0,0.5,1.0]")
    assert len(tiled) == 19
    for index, expected in enumerate((0.0, 0.5, 1.0, 0.0, 0.5)):
        assert tiled[index] == expected

    # A 25-entry list is truncated to 19 slots.
    long_spec = "[" + ",".join(["0.5"] * 25) + "]"
    truncated = parse_blocks(long_spec)
    assert len(truncated) == 19
    assert all(v == 0.5 for v in truncated)
|
|
|
|
|
|
def test_default_with_overrides():
    """A leading scalar sets the default; `index:value` pairs override single slots."""
    # Default 1.0 with only block 0 overridden.
    blocks = parse_blocks("1.0,0:0.5")
    assert len(blocks) == 19
    assert blocks[0] == 0.5
    assert all(v == 1.0 for v in blocks[1:])

    # Default 0.5 with three overrides; spot-check overridden and untouched slots.
    blocks = parse_blocks("0.5,1:0.7,5:0.9,10:0.3")
    assert len(blocks) == 19
    expected_at = {
        0: 0.5,   # default
        1: 0.7,   # override
        5: 0.9,   # override
        10: 0.3,  # override
        18: 0.5,  # default
    }
    for index, value in expected_at.items():
        assert blocks[index] == value

    # No leading scalar: every slot except the override falls back to 0.0.
    blocks = parse_blocks("3:0.8")
    assert len(blocks) == 19
    assert blocks[3] == 0.8
    assert all(v == 0.0 for pos, v in enumerate(blocks) if pos != 3)
|
|
|
|
|
|
def test_range_overrides():
    """`start-end:value` assigns value to every index in the inclusive range."""
    # Single range on top of the implicit 0.0 default.
    result = parse_blocks("1-5:0.7")
    assert len(result) == 19
    assert all(result[i] == 0.7 for i in range(1, 6))
    assert all(val == 0.0 for i, val in enumerate(result) if i < 1 or i > 5)

    # Two disjoint ranges on top of an explicit 0.1 default.
    result = parse_blocks("0.1,1-3:0.5,7-9:0.8")
    assert len(result) == 19
    assert all(result[i] == 0.5 for i in range(1, 4))
    assert all(result[i] == 0.8 for i in range(7, 10))
    # Every index outside both ranges (0, 4-6, 10-18) must keep the default.
    # Checking all of them, not just 0/6/18, catches a range that bleeds
    # past its inclusive end into the gap or the tail.
    outside = [i for i in range(19) if i not in range(1, 4) and i not in range(7, 10)]
    assert all(result[i] == 0.1 for i in outside), "indices outside both ranges must keep the default"
|
|
|
|
|
|
def test_cos_function():
    """`cos` fills a range with a half-cosine ramp running from the high to the low bound."""

    def half_cosine(step, steps):
        # Maps step 0..steps-1 onto 1 -> 0 via (1 + cos(x)) / 2 with x swept over [0, pi].
        return (1 + math.cos(step / (steps - 1) * math.pi)) / 2

    # Unparameterised: blocks 1..5 go from 1.0 down to 0.0.
    result = parse_blocks("1-5:cos")
    assert len(result) == 19
    for step, index in enumerate(range(1, 6)):
        assert result[index] == pytest.approx(half_cosine(step, 5))

    # Parameterised: the same curve rescaled into [0.2, 0.8] over blocks 3..7,
    # so it starts at 0.8 and ends at 0.2.
    result = parse_blocks("3-7:cos(0.2,0.8)")
    assert len(result) == 19
    low, high = 0.2, 0.8
    for step in range(5):
        assert result[3 + step] == pytest.approx(low + half_cosine(step, 5) * (high - low))
|
|
|
|
|
|
def test_sin_function():
    """`sin` fills a range with a quarter-sine ramp running from the low to the high bound."""
    quarter_turn = math.pi / 2

    # Unparameterised: blocks 2..6 rise from 0.0 to 1.0.
    result = parse_blocks("2-6:sin")
    assert len(result) == 19
    for offset in range(5):
        assert result[2 + offset] == pytest.approx(math.sin(offset / (6 - 2) * quarter_turn))

    # Parameterised: the same curve rescaled into [0.3, 0.9] over blocks 4..8.
    result = parse_blocks("4-8:sin(0.3,0.9)")
    assert len(result) == 19
    for offset in range(5):
        ramp = math.sin(offset / (5 - 1) * quarter_turn)
        assert result[4 + offset] == pytest.approx(0.3 + ramp * (0.9 - 0.3))
|
|
|
|
|
|
def test_linear_function():
    """`linear` interpolates evenly across an inclusive block range."""
    # Unparameterised: blocks 3..7 step evenly from 0.0 to 1.0.
    result = parse_blocks("3-7:linear")
    assert len(result) == 19
    for offset, index in enumerate(range(3, 8)):
        assert result[index] == pytest.approx(offset / (7 - 3))

    # Parameterised: blocks 5..9 step evenly from 0.4 to 0.7.
    result = parse_blocks("5-9:linear(0.4,0.7)")
    assert len(result) == 19
    start, stop = 0.4, 0.7
    for offset in range(5):
        fraction = offset / 4  # position within the range, 0..1
        assert result[5 + offset] == pytest.approx(start + fraction * (stop - start))
|
|
|
|
|
|
def test_reverse_linear_function():
    """`reverse_linear(a, b)` interpolates from b back to a across an inclusive range."""
    # Unparameterised: blocks 2..6 step evenly from 1.0 down to 0.0.
    result = parse_blocks("2-6:reverse_linear")
    assert len(result) == 19
    for offset, index in enumerate(range(2, 7)):
        assert result[index] == pytest.approx(1 - offset / (6 - 2))

    # Parameterised with (0.8, 0.2): running from the second bound back to the
    # first means blocks 10..15 rise from 0.2 to 0.8.
    result = parse_blocks("10-15:reverse_linear(0.8,0.2)")
    assert len(result) == 19
    for offset in range(6):
        fraction = offset / 5  # position within the range, 0..1
        assert result[10 + offset] == pytest.approx(0.2 + fraction * (0.8 - 0.2))
|
|
|
|
|
|
def test_custom_length():
    """`length=` replaces the block count of 19 that is produced when it is omitted."""
    # Scalar broadcast into five slots.
    scalar = parse_blocks("1.0", length=5)
    assert len(scalar) == 5
    assert all(v == 1.0 for v in scalar)

    # A three-entry list tiled into ten slots; index 3 wraps back to the first entry.
    tiled = parse_blocks("[0.1,0.2,0.3]", length=10)
    assert len(tiled) == 10
    for index in (0, 3):
        assert tiled[index] == 0.1

    # A range override inside seven slots; both ends stay at the 0.0 default.
    ranged = parse_blocks("1-3:0.5", length=7)
    assert len(ranged) == 7
    assert all(ranged[i] == 0.5 for i in range(1, 4))
    first, last = ranged[0], ranged[-1]
    assert first == 0.0
    assert last == 0.0
|
|
|
|
|
|
def test_custom_default():
    """`default=` fills unset slots unless the spec string supplies its own default."""
    # Only block 1 is set; the keyword default (0.2) fills the rest.
    out = parse_blocks("1:0.5", default=0.2)
    assert len(out) == 19
    for index, expected in ((1, 0.5), (0, 0.2), (18, 0.2)):
        assert out[index] == expected

    # A leading scalar in the spec (0.7) takes precedence over the keyword default.
    out = parse_blocks("0.7,1:0.5", default=0.2)
    assert len(out) == 19
    for index, expected in ((1, 0.5), (0, 0.7), (18, 0.7)):
        assert out[index] == expected
|
|
|
|
|
|
def test_out_of_bounds_indices():
    """Indices outside [0, length) are dropped rather than raising or wrapping."""
    # A negative index is ignored entirely.
    result = parse_blocks("-5:0.9")
    assert len(result) == 19
    assert all(val == 0.0 for val in result), "Negative index should be ignored"

    # An index past the end is ignored entirely.
    result = parse_blocks("25:0.8")
    assert len(result) == 19
    assert all(val == 0.0 for val in result), "Indices above the max length should be ignored"

    # A range straddling the end is clipped: 17-18 are set and 19-22 are dropped.
    # Also verify nothing before the range start changed, so an implementation
    # that wraps 19-22 around onto 0-3 is caught.
    result = parse_blocks("17-22:0.7")
    assert len(result) == 19
    assert result[17] == 0.7
    assert result[18] == 0.7
    assert all(val == 0.0 for val in result[:17]), "Clipped range must not touch indices before its start"
|
|
|
|
|
|
def test_mixed_patterns():
    """A default, a single index, a function range and a parameterised function combine in one spec."""
    result = parse_blocks("0.3,2:0.8,5-8:cos,10-15:linear(0.1,0.9)")
    assert len(result) == 19
    assert result[0] == 0.3  # Default
    assert result[2] == 0.8  # Single index

    # Blocks 5..8 follow the half-cosine ramp from 1.0 to 0.0 over four points.
    cos_range = range(5, 9)
    expected_cos = [(1 + math.cos(i / (8 - 5) * math.pi)) / 2 for i in range(4)]
    for i, idx in enumerate(cos_range):
        assert result[idx] == pytest.approx(expected_cos[i])

    # Blocks 10..15 step evenly from 0.1 to 0.9.
    linear_range = range(10, 16)
    for i, idx in enumerate(linear_range):
        t = i / 5  # normalized position
        expected = 0.1 + t * (0.9 - 0.1)
        assert result[idx] == pytest.approx(expected)

    # Every index that no override touches (1, 3, 4, 9, 16-18) must keep the
    # 0.3 default; previously only index 0 was checked.
    overridden = {2, *cos_range, *linear_range}
    for idx in range(19):
        if idx not in overridden:
            assert result[idx] == 0.3, f"index {idx} should keep the default"
|
|
|
|
|
|
def test_edge_cases():
    """Degenerate specs: blank input, empty lists, one-element ranges and overlapping ranges."""
    # Empty string, whitespace-only string and an empty list all yield the 0.0 default.
    for blank in ("", " ", "[]"):
        result = parse_blocks(blank)
        assert len(result) == 19
        assert all(val == 0.0 for val in result), f"spec {blank!r} should yield all zeros"

    # A one-element range sets exactly that index.
    result = parse_blocks("5-5:0.7")
    assert len(result) == 19
    assert result[5] == 0.7
    assert result[4] == 0.0
    assert result[6] == 0.0

    # A function over a one-element range is evaluated at position 0, where cos
    # gives 1. The value is computed, so compare approximately like every other
    # curve test instead of with exact float equality.
    result = parse_blocks("7-7:cos")
    assert len(result) == 19
    assert result[7] == pytest.approx(1.0)

    # Overlapping ranges: the later definition wins on 3-5.
    result = parse_blocks("1-5:0.3,3-7:0.8")
    assert len(result) == 19
    assert result[0] == 0.0
    assert result[1] == 0.3
    assert result[2] == 0.3
    # The later range covers all of 3..7 (including 6, previously unchecked).
    for idx in range(3, 8):
        assert result[idx] == 0.8
    assert result[8] == 0.0
|
|
|
|
|
|
def test_malformed_input():
    """Malformed specs are tolerated: they fall back to defaults instead of raising."""
    # An unterminated list is rejected as a whole.
    result = parse_blocks("[0.1,0.2,")
    assert len(result) == 19
    assert all(val == 0.0 for val in result), "malformed list"

    # A range with no end ("5-") is expected to set index 5 but not index 6.
    result = parse_blocks("5-:0.7")
    assert len(result) == 19
    assert result[5] == 0.7
    assert result[6] == 0.0

    # Each of these specs must leave every block at the 0.0 default.
    rejected = (
        ("-5:0.7", "invalid start range, indices should never be negative"),
        ("1-5:unknown_func", "Function name not recognized"),
        ("1-5:cos(invalid,0.8)", "Invalid parameters"),
    )
    for spec, reason in rejected:
        result = parse_blocks(spec)
        assert len(result) == 19
        assert all(val == 0.0 for val in result), reason
|