"""Helpers for parsing per-block value specifications (library/lora_util.py)."""
from collections.abc import MutableSequence
import re
import math
import warnings
from typing import Optional, Union

# A plain decimal literal such as "1", "-0.5" or "12.25" (no exponent form).
_NUMBER = r"-?\d+(?:\.\d+)?"
# "name(a,b)" with two numeric arguments; whitespace around them is tolerated.
_FUNC_CALL_RE = re.compile(rf"(\w+)\(\s*({_NUMBER})\s*,\s*({_NUMBER})\s*\)")
# A comma that is NOT followed by a ")" without an intervening parenthesis,
# i.e. a comma that separates entries rather than function arguments.
_TOP_LEVEL_COMMA_RE = re.compile(r",(?![^()]*\))")


def _cos_curve(t, start_val, end_val):
    """Half-cosine easing: ``end_val`` at t=0, ``start_val`` at t=1."""
    return start_val + (1 + math.cos(t * math.pi)) / 2 * (end_val - start_val)


def _sin_curve(t, start_val, end_val):
    """Quarter-sine easing: ``start_val`` at t=0, ``end_val`` at t=1."""
    return start_val + math.sin(t * (math.pi / 2)) * (end_val - start_val)


def _linear_curve(t, start_val, end_val):
    """Straight line from ``start_val`` (t=0) to ``end_val`` (t=1)."""
    return start_val + t * (end_val - start_val)


def _reverse_linear_curve(t, start_val, end_val):
    """Straight line from ``end_val`` (t=0) to ``start_val`` (t=1)."""
    return end_val + t * (start_val - end_val)


# Function names accepted after the ':' of a range entry.
_CURVES = {
    "cos": _cos_curve,
    "sin": _sin_curve,
    "linear": _linear_curve,
    "reverse_linear": _reverse_linear_curve,
}


def _parse_index_range(indices_part):
    """Parse ``"start-end"`` (end may be blank or parenthesised) into two ints."""
    start_text, end_text = indices_part.split("-", 1)
    end_text = end_text.strip()
    try:
        start_idx = int(start_text)
        if end_text == "":
            warnings.warn("Range end was missing, setting to start of range")
            return start_idx, start_idx
        if end_text.startswith("(") and end_text.endswith(")"):
            # "(-3)" style end. Parsed as a literal integer; the text comes
            # from the user and must never be handed to eval().
            end_text = end_text[1:-1]
        return start_idx, int(end_text)
    except ValueError as err:
        raise ValueError(f"Could not parse range '{indices_part}'") from err


def parse_blocks(input_str: Optional[Union[str, float]], length=19, default: Optional[float]=0.0) -> MutableSequence[Optional[float]]:
    """
    Parse a block specification and return one value per block.

    Supported forms of ``input_str``:

    * ``None`` or ``""``: every block gets ``default``.
    * A number (``0.5``, ``"1.0"``, ``"-1"``): every block gets that value.
    * An explicit list ``"[0,0.5,1]"``: repeated to fill ``length`` or
      truncated to it. Brackets with no numbers inside give ``default``.
    * Comma-separated entries, optionally led by a bare number that
      replaces ``default``: ``"1.0,0:0.5,3-7:cos,10-15:linear(0.1,0.9)"``.
      Each entry is ``index:value`` or ``start-end:value``; for ranges the
      value may also be ``cos``, ``sin``, ``linear`` or ``reverse_linear``,
      with optional ``(start_val,end_val)`` arguments (defaults ``0,1``).
      Later entries overwrite earlier ones.

    Ranges that only partly overlap ``[0, length)`` are clamped to it, and
    a curve is then spread over the clamped range. Single indices and whole
    ranges that lie outside ``[0, length)`` are ignored.

    Args:
        input_str (str | float | None): The specification (the text after '=').
        length (int): The desired length of the output list (default: 19).
        default (float | None): Value for blocks the specification does not set.

    Returns:
        list: A list with ``length`` entries.

    Raises:
        RuntimeError: A single-index entry has an unparsable index or value.
        ValueError: A range has unparsable bounds, or an entry contains
            more than one ':'.

    Warns:
        UserWarning: For recoverable problems (missing range end,
            unparsable range value, unknown function, function used on a
            single index, bracketed list without numbers).
    """
    if input_str is None:
        return [default] * length

    # Real numbers are used directly; going through str() would turn e.g.
    # 1e-05 into "1e-05", which the decimal-literal check below rejects.
    if isinstance(input_str, (int, float)) and not isinstance(input_str, bool):
        return [float(input_str)] * length

    spec = input_str.strip() if isinstance(input_str, str) else str(input_str).strip()
    if spec == "":
        return [default] * length

    # Case: single value (e.g. "1.0" or "-1.0")
    if re.fullmatch(_NUMBER, spec):
        return [float(spec)] * length

    # Case: explicit list (e.g. "[0,0,1,1,0.9,0.8,0.6]")
    if spec.startswith("[") and spec.endswith("]"):
        values = [float(x) for x in re.findall(_NUMBER, spec)]
        if not values:
            if spec[1:-1].strip() != "":
                warnings.warn(f"No numeric values found in list '{spec}'")
            return [default] * length
        if len(values) < length:
            # Repeat the pattern until it covers the requested length.
            values = values * (length // len(values) + 1)
        return values[:length]

    # Case: optional default followed by "index:value" / "start-end:value".
    parts = [part.strip() for part in _TOP_LEVEL_COMMA_RE.split(spec)]
    result = [default] * length
    if ":" not in parts[0] and re.fullmatch(_NUMBER, parts[0]):
        result = [float(parts[0])] * length
        parts = parts[1:]

    for part in parts:
        if ":" not in part:
            continue  # Only a leading default may omit the colon.
        if part.count(":") != 1:
            raise ValueError(f"Expected exactly one ':' in '{part}'")

        indices_part, value_part = (piece.strip() for piece in part.split(":"))

        # A leading '-' is a (negative) single index, not a range.
        if "-" in indices_part and not indices_part.startswith("-"):
            start_idx, end_idx = _parse_index_range(indices_part)

            # A range lying entirely outside the list must not touch it;
            # clamping it would overwrite the first or last block.
            if start_idx > length - 1 or end_idx < 0:
                continue
            start_idx = max(0, start_idx)
            end_idx = min(end_idx, length - 1)
            range_length = end_idx - start_idx + 1

            func_match = _FUNC_CALL_RE.match(value_part)
            if func_match:
                func_name = func_match.group(1)
                start_val = float(func_match.group(2))
                end_val = float(func_match.group(3))
            else:
                # A bare function name runs over the default 0..1 interval.
                func_name, start_val, end_val = value_part, 0.0, 1.0

            curve = _CURVES.get(func_name)
            if curve is not None:
                for offset in range(range_length):
                    # Normalised position in [0, 1]; a single point sits at 0.
                    t = offset / (range_length - 1) if range_length > 1 else 0.0
                    result[start_idx + offset] = curve(t, start_val, end_val)
            elif func_match:
                warnings.warn(f"Unknown function '{func_name}' in '{part}'")
            else:
                # Regular numeric value
                try:
                    value = float(value_part)
                except ValueError:
                    warnings.warn(f"Could not parse value '{value_part}'")
                    continue
                for i in range(start_idx, end_idx + 1):
                    result[i] = value

        # Handle single index (e.g. "1")
        else:
            try:
                index = int(indices_part)
            except ValueError as err:
                raise RuntimeError(f"Could not parse index '{indices_part}'") from err

            if not 0 <= index < length:
                continue  # Out-of-bounds indices are ignored.

            if "(" in value_part and ")" in value_part:
                warnings.warn(f"Functions with parameters not supported for single indices: {part}")
                continue

            try:
                result[index] = float(value_part)
            except ValueError as err:
                raise RuntimeError(
                    f"Could not parse value '{value_part}' for index '{indices_part}'"
                ) from err

    return result
assert result[0] == 0.5 # Default value + assert result[1] == 0.7 # Override + assert result[5] == 0.9 # Override + assert result[10] == 0.3 # Override + assert result[18] == 0.5 # Default value + + # Test without default value (should use 0.0) + result = parse_blocks("3:0.8") + assert len(result) == 19 + assert result[3] == 0.8 + assert all(val == 0.0 for i, val in enumerate(result) if i != 3) + + +def test_range_overrides(): + # Test simple range + result = parse_blocks("1-5:0.7") + assert len(result) == 19 + assert all(result[i] == 0.7 for i in range(1, 6)) + assert all(val == 0.0 for i, val in enumerate(result) if i < 1 or i > 5) + + # Test multiple ranges + result = parse_blocks("0.1,1-3:0.5,7-9:0.8") + assert len(result) == 19 + assert all(result[i] == 0.5 for i in range(1, 4)) + assert all(result[i] == 0.8 for i in range(7, 10)) + assert result[0] == 0.1 # Default + assert result[6] == 0.1 # Default + assert result[18] == 0.1 # Default + + +def test_cos_function(): + # Test cos over range + result = parse_blocks("1-5:cos") + assert len(result) == 19 + # Calculate expected values for cosine function + expected_cos = [(1 + math.cos(i / (5 - 1) * math.pi)) / 2 for i in range(5)] + for i in range(1, 6): + assert result[i] == pytest.approx(expected_cos[i - 1]) + + # Test parameterized cos + result = parse_blocks("3-7:cos(0.2,0.8)") + assert len(result) == 19 + # Cos goes from 1 to 0 over π, scaled to range 0.2 to 0.8 + for i in range(5): + normalized = (1 + math.cos(i / (5 - 1) * math.pi)) / 2 + expected = 0.2 + normalized * (0.8 - 0.2) + assert result[i + 3] == pytest.approx(expected) + + +def test_sin_function(): + # Test sin over range + result = parse_blocks("2-6:sin") + assert len(result) == 19 + # Calculate expected values for sine function + expected_sin = [math.sin(i / (6 - 2) * (math.pi / 2)) for i in range(5)] + for i in range(2, 7): + assert result[i] == pytest.approx(expected_sin[i - 2]) + + # Test parameterized sin + result = 
parse_blocks("4-8:sin(0.3,0.9)") + assert len(result) == 19 + # Sin goes from 0 to 1 over π/2, scaled to range 0.3 to 0.9 + for i in range(5): + normalized = math.sin(i / (5 - 1) * (math.pi / 2)) + expected = 0.3 + normalized * (0.9 - 0.3) + assert result[i + 4] == pytest.approx(expected) + + +def test_linear_function(): + # Test linear over range + result = parse_blocks("3-7:linear") + assert len(result) == 19 + # Calculate expected values for linear function (0 to 1) + expected_linear = [i / (7 - 3) for i in range(5)] + for i in range(3, 8): + assert result[i] == pytest.approx(expected_linear[i - 3]) + + # Test parameterized linear + result = parse_blocks("5-9:linear(0.4,0.7)") + assert len(result) == 19 + # Linear goes from 0.4 to 0.7 + for i in range(5): + t = i / 4 # normalized position + expected = 0.4 + t * (0.7 - 0.4) + assert result[i + 5] == pytest.approx(expected) + + +def test_reverse_linear_function(): + # Test reverse_linear over range + result = parse_blocks("2-6:reverse_linear") + assert len(result) == 19 + # Calculate expected values for reverse linear function (1 to 0) + expected_reverse = [1 - i / (6 - 2) for i in range(5)] + for i in range(2, 7): + assert result[i] == pytest.approx(expected_reverse[i - 2]) + + # Test parameterized reverse_linear + result = parse_blocks("10-15:reverse_linear(0.8,0.2)") + assert len(result) == 19 + # Reverse linear goes from 0.2 to 0.8 (reversed) + for i in range(6): + t = i / 5 # normalized position + expected = 0.2 + t * (0.8 - 0.2) + assert result[i + 10] == pytest.approx(expected) + + +def test_custom_length(): + # Test with custom length + result = parse_blocks("1.0", length=5) + assert len(result) == 5 + assert all(val == 1.0 for val in result) + + # Test list with custom length + result = parse_blocks("[0.1,0.2,0.3]", length=10) + assert len(result) == 10 + assert result[0] == 0.1 + assert result[3] == 0.1 # Pattern repeats + + # Test ranges with custom length + result = parse_blocks("1-3:0.5", length=7) + 
assert len(result) == 7 + assert all(result[i] == 0.5 for i in range(1, 4)) + assert result[0] == 0.0 + assert result[6] == 0.0 + + +def test_custom_default(): + # Test with custom default value + result = parse_blocks("1:0.5", default=0.2) + assert len(result) == 19 + assert result[1] == 0.5 + assert result[0] == 0.2 + assert result[18] == 0.2 + + # Test overriding default value + result = parse_blocks("0.7,1:0.5", default=0.2) + assert len(result) == 19 + assert result[1] == 0.5 + assert result[0] == 0.7 # Explicitly set default + assert result[18] == 0.7 + + +def test_out_of_bounds_indices(): + # Test negative indices (should be ignored) + result = parse_blocks("-5:0.9") + assert len(result) == 19 + assert all(val == 0.0 for val in result), "Negative index should be ignored" + + # Test indices beyond length + result = parse_blocks("25:0.8") + assert len(result) == 19 + assert all(val == 0.0 for val in result), "Indices above the max length should be ignored" + + # Test range partially out of bounds + result = parse_blocks("17-22:0.7") + assert len(result) == 19 + assert result[17] == 0.7 + assert result[18] == 0.7 + # Indices 19-22 would be out of bounds + + +def test_mixed_patterns(): + # Test combining different formats + result = parse_blocks("0.3,2:0.8,5-8:cos,10-15:linear(0.1,0.9)") + assert len(result) == 19 + assert result[0] == 0.3 # Default + assert result[2] == 0.8 # Single index + + # Check cos values + cos_range = range(5, 9) + expected_cos = [(1 + math.cos(i / (8 - 5) * math.pi)) / 2 for i in range(4)] + for i, idx in enumerate(cos_range): + assert result[idx] == pytest.approx(expected_cos[i]) + + # Check linear values + linear_range = range(10, 16) + for i, idx in enumerate(linear_range): + t = i / 5 # normalized position + expected = 0.1 + t * (0.9 - 0.1) + assert result[idx] == pytest.approx(expected) + + +def test_edge_cases(): + # Test empty string + result = parse_blocks("") + assert len(result) == 19 + assert all(val == 0.0 for val in result) 
+ + # Test whitespace + result = parse_blocks(" ") + assert len(result) == 19 + assert all(val == 0.0 for val in result) + + # Test empty list + result = parse_blocks("[]") + assert len(result) == 19 + assert all(val == 0.0 for val in result) + + # Test single-item range + result = parse_blocks("5-5:0.7") + assert len(result) == 19 + assert result[5] == 0.7 + assert result[4] == 0.0 + assert result[6] == 0.0 + + # Test function with single-item range + result = parse_blocks("7-7:cos") + assert len(result) == 19 + assert result[7] == 1.0 # When range is single point, cos at position 0 is 1 + + # Test overlapping ranges + result = parse_blocks("1-5:0.3,3-7:0.8") + assert len(result) == 19 + assert result[1] == 0.3 + assert result[2] == 0.3 + assert result[3] == 0.8 # Later definition overwrites + assert result[4] == 0.8 # Later definition overwrites + assert result[5] == 0.8 # Later definition overwrites + assert result[7] == 0.8 + assert result[8] == 0.0 + + +def test_malformed_input(): + # Test malformed list + result = parse_blocks("[0.1,0.2,") + assert len(result) == 19 + assert all(val == 0.0 for val in result), "malformed list" + + # Test invalid end range + result = parse_blocks("5-:0.7") + assert len(result) == 19 + assert result[5] == 0.7 + assert result[6] == 0.0 + + # Test invalid start range, indices should never be negative + result = parse_blocks("-5:0.7") + assert len(result) == 19 + assert all(val == 0.0 for val in result), "invalid start range, indices should never be negative" + + # Test invalid function + result = parse_blocks("1-5:unknown_func") + assert len(result) == 19 + assert all(val == 0.0 for val in result), "Function name not recognized" + + # Test invalid function parameters + result = parse_blocks("1-5:cos(invalid,0.8)") + assert len(result) == 19 + assert all(val == 0.0 for val in result), "Invalid parameters"