mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Compare commits
4 Commits
748566a9a0
...
333215a805
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
333215a805 | ||
|
|
fa53f71ec0 | ||
|
|
ef3a110ae1 | ||
|
|
0392a57210 |
@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
|
||||
|
||||
### 更新履歴
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
|
||||
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。
|
||||
|
||||
@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
|
||||
|
||||
### Change History
|
||||
|
||||
- **Version 0.10.3 (2026-04-02):**
|
||||
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
|
||||
|
||||
- **Version 0.10.2 (2026-03-30):**
|
||||
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
|
||||
- Please refer to the [documentation](./docs/train_leco.md) for details.
|
||||
|
||||
@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
|
||||
x_B_T_H_W_D: torch.Tensor,
|
||||
emb_B_T_D: torch.Tensor,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
use_fp32: bool = False,
|
||||
):
|
||||
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
|
||||
if self.use_adaln_lora:
|
||||
assert adaln_lora_B_T_3D is not None
|
||||
@@ -863,11 +863,11 @@ class Block(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
use_fp32: bool = False,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
if use_fp32:
|
||||
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
@@ -959,6 +959,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D: torch.Tensor,
|
||||
crossattn_emb: torch.Tensor,
|
||||
attn_params: attention.AttentionParams,
|
||||
use_fp32: bool = False,
|
||||
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
|
||||
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
@@ -972,6 +973,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -994,6 +996,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
|
||||
emb_B_T_D,
|
||||
crossattn_emb,
|
||||
attn_params,
|
||||
use_fp32,
|
||||
rope_emb_L_1_1_D,
|
||||
adaln_lora_B_T_3D,
|
||||
extra_per_block_pos_emb,
|
||||
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
|
||||
|
||||
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
|
||||
|
||||
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
|
||||
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
|
||||
|
||||
for block_idx, block in enumerate(self.blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.wait_for_block(block_idx)
|
||||
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
|
||||
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader.submit_move_blocks(self.blocks, block_idx)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
|
||||
225
library/lora_util.py
Normal file
225
library/lora_util.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from collections.abc import MutableSequence
|
||||
import re
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
def parse_blocks(input_str: Optional[Union[str, float]], length=19, default: Optional[float]=0.0) -> MutableSequence[Optional[float]]:
|
||||
"""
|
||||
Parse different formats of block specifications and return a list of values.
|
||||
|
||||
Args:
|
||||
input_str (str): The input string after the '=' sign
|
||||
length (int): The desired length of the output list (default: 19)
|
||||
|
||||
Returns:
|
||||
list: A list of float values with the specified length
|
||||
"""
|
||||
input_str = f"{input_str}" if not isinstance(input_str, str) else input_str.strip()
|
||||
result = [default] * length # Initialize with default
|
||||
|
||||
if input_str == "":
|
||||
return [default] * length
|
||||
|
||||
# Case: Single value (e.g., "1.0" or "-1.0")
|
||||
if re.match(r'^-?\d+(\.\d+)?$', input_str):
|
||||
value = float(input_str)
|
||||
return [value] * length
|
||||
|
||||
# Case: Explicit list (e.g., "[0,0,1,1,0.9,0.8,0.6]")
|
||||
if input_str.startswith("[") and input_str.endswith("]"):
|
||||
if input_str[1:-1].strip() == "":
|
||||
return [default] * length
|
||||
|
||||
# Use regex to properly split on commas while handling negative numbers
|
||||
values = [float(x) for x in re.findall(r'-?\d+(?:\.\d+)?', input_str)]
|
||||
# If list is shorter than required length, repeat the pattern
|
||||
if len(values) < length:
|
||||
values = (values * (length // len(values) + 1))[:length]
|
||||
# If list is longer than required length, truncate
|
||||
return values[:length]
|
||||
|
||||
# Pre-process to handle function parameters with commas
|
||||
# Replace function parameters with placeholders
|
||||
function_params = {}
|
||||
placeholder_counter = 0
|
||||
|
||||
def replace_function(match):
|
||||
nonlocal placeholder_counter
|
||||
func_with_params = match.group(0)
|
||||
placeholder = f"FUNC_PLACEHOLDER_{placeholder_counter}"
|
||||
function_params[placeholder] = func_with_params
|
||||
placeholder_counter += 1
|
||||
return placeholder
|
||||
|
||||
# Find function calls with parameters and replace them
|
||||
preprocessed_str = re.sub(r'\w+\([^)]+\)', replace_function, input_str)
|
||||
|
||||
# Case: Default value with specific overrides (e.g., "1.0,0:0.5")
|
||||
parts = preprocessed_str.split(',')
|
||||
default_value = default
|
||||
|
||||
# Check if the first part is a default value (no colon)
|
||||
if ':' not in parts[0] and re.match(r'^-?\d+(\.\d+)?$', parts[0]):
|
||||
default_value = float(parts[0])
|
||||
parts = parts[1:] # Remove the default value from parts
|
||||
# Fill the result with the default value
|
||||
result = [default_value] * length
|
||||
|
||||
# Process the remaining parts as ranges or single indices
|
||||
for part in parts:
|
||||
if ':' not in part:
|
||||
continue # Skip parts without colon (should only be the default value)
|
||||
|
||||
indices_part, value_part = part.split(':')
|
||||
|
||||
# Restore any function placeholders
|
||||
for placeholder, original in function_params.items():
|
||||
if placeholder in value_part:
|
||||
value_part = value_part.replace(placeholder, original)
|
||||
|
||||
# Handle range (e.g., "10-18" or "-5-10")
|
||||
if '-' in indices_part and not indices_part.startswith('-'):
|
||||
# This is a range with a dash (not just a negative number)
|
||||
range_parts = indices_part.split('-', 1) # Split on first dash only
|
||||
|
||||
# Handle potential negative values in the range
|
||||
if range_parts[0] == '':
|
||||
# Handle case like "-5-10" (from -5 to 10)
|
||||
start_idx = int('-' + range_parts[1].split('-')[0])
|
||||
end_idx = int(range_parts[1].split('-')[1])
|
||||
else:
|
||||
# Normal case like "5-10" or "-5-(-3)"
|
||||
start_idx = int(range_parts[0])
|
||||
end_idx_str = range_parts[1]
|
||||
|
||||
# Handle potentially complex end index expressions
|
||||
if end_idx_str.startswith('(') and end_idx_str.endswith(')'):
|
||||
# Handle expressions like "(-3)"
|
||||
end_idx = eval(end_idx_str)
|
||||
else:
|
||||
print("end_idx_str", end_idx_str)
|
||||
# If end str is blank, set to start idx
|
||||
if end_idx_str == "":
|
||||
warnings.warn("Range end was missing, setting to start of range")
|
||||
end_idx = start_idx
|
||||
else:
|
||||
end_idx = int(end_idx_str)
|
||||
|
||||
# Make sure indices are within bounds
|
||||
start_idx = max(0, min(start_idx, length-1))
|
||||
end_idx = max(0, min(end_idx, length-1))
|
||||
range_length = end_idx - start_idx + 1
|
||||
|
||||
# Check if we have a function with parameters
|
||||
# Checking function and 2 numbers (float and int) separated by ,
|
||||
# cos(0.2, 0.8), cos(0, 1.0), cos(1, 0.1)
|
||||
func_match = re.match(r'(\w+)\((\d+|\d+\.\d+),(\d+|\d+\.\d+)\)', value_part)
|
||||
if func_match:
|
||||
func_name = func_match.group(1)
|
||||
start_val = float(func_match.group(2))
|
||||
end_val = float(func_match.group(3))
|
||||
|
||||
if func_name == 'cos':
|
||||
# Implement parameterized cosine
|
||||
for i in range(range_length):
|
||||
# Calculate position in the range from 0 to π (half a period)
|
||||
position = i / (range_length - 1) * math.pi if range_length > 1 else 0
|
||||
# Cosine from 1 at 0 to 0 at π, scaled to requested range
|
||||
normalized_value = (1 + math.cos(position)) / 2
|
||||
# Scale and shift to the requested start and end values
|
||||
value = start_val + normalized_value * (end_val - start_val)
|
||||
if start_idx + i < length:
|
||||
result[start_idx + i] = value
|
||||
|
||||
elif func_name == 'sin':
|
||||
# Implement parameterized sine
|
||||
for i in range(range_length):
|
||||
# Calculate position in the range from 0 to π/2 (quarter period)
|
||||
position = i / (range_length - 1) * (math.pi/2) if range_length > 1 else 0
|
||||
# Sine from 0 at 0 to 1 at π/2, scaled to requested range
|
||||
normalized_value = math.sin(position)
|
||||
# Scale and shift to the requested start and end values
|
||||
value = start_val + normalized_value * (end_val - start_val)
|
||||
if start_idx + i < length:
|
||||
result[start_idx + i] = value
|
||||
|
||||
elif func_name == 'linear':
|
||||
# Implement parameterized linear function
|
||||
for i in range(range_length):
|
||||
# Linear interpolation from start_val to end_val
|
||||
t = i / (range_length - 1) if range_length > 1 else 0
|
||||
value = start_val + t * (end_val - start_val)
|
||||
if start_idx + i < length:
|
||||
result[start_idx + i] = value
|
||||
|
||||
elif func_name == 'reverse_linear':
|
||||
# Implement parameterized reverse linear function
|
||||
for i in range(range_length):
|
||||
# Linear interpolation from end_val to start_val
|
||||
t = i / (range_length - 1) if range_length > 1 else 0
|
||||
value = end_val + t * (start_val - end_val)
|
||||
if start_idx + i < length:
|
||||
result[start_idx + i] = value
|
||||
|
||||
# Handle non-parameterized functions
|
||||
elif value_part == 'cos':
|
||||
# Default cosine from 1 to 0
|
||||
for i in range(range_length):
|
||||
position = i / (range_length - 1) * math.pi if range_length > 1 else 0
|
||||
value = (1 + math.cos(position)) / 2
|
||||
if start_idx + i < length:
|
||||
result[start_idx + i] = value
|
||||
|
||||
elif value_part == 'sin':
|
||||
# Default sine from 0 to 1
|
||||
for i in range(range_length):
|
||||
position = i / (range_length - 1) * (math.pi/2) if range_length > 1 else 0
|
||||
value = math.sin(position)
|
||||
if start_idx + i < length:
|
||||
result[start_idx + i] = value
|
||||
|
||||
elif value_part == 'linear':
|
||||
# Default linear from 0 to 1
|
||||
for i in range(range_length):
|
||||
value = i / (range_length - 1) if range_length > 1 else 0
|
||||
if start_idx + i < length:
|
||||
result[start_idx + i] = value
|
||||
|
||||
elif value_part == 'reverse_linear':
|
||||
# Default reverse linear from 1 to 0
|
||||
for i in range(range_length):
|
||||
value = 1 - (i / (range_length - 1) if range_length > 1 else 0)
|
||||
if start_idx + i < length:
|
||||
result[start_idx + i] = value
|
||||
|
||||
else:
|
||||
# Regular numeric value
|
||||
try:
|
||||
value = float(value_part)
|
||||
for i in range(start_idx, end_idx + 1):
|
||||
if 0 <= i < length:
|
||||
result[i] = value
|
||||
except ValueError:
|
||||
warnings.warn(f"Could not parse value '{value_part}'")
|
||||
|
||||
# Handle single index (e.g., "1")
|
||||
else:
|
||||
try:
|
||||
index = int(indices_part)
|
||||
if 0 <= index < length:
|
||||
# Check if we have a function with parameters (unlikely for single index)
|
||||
if '(' in value_part and ')' in value_part:
|
||||
warnings.warn("Functions with parameters not supported for single indices: {part}")
|
||||
continue
|
||||
|
||||
# Assuming a single index won't have a function pattern, just a value
|
||||
value = float(value_part)
|
||||
result[index] = value
|
||||
except ValueError:
|
||||
raise RuntimeError(f"Could not parse index '{indices_part}'")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
303
tests/library/test_lora_util_blocks.py
Normal file
303
tests/library/test_lora_util_blocks.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import pytest
|
||||
import math
|
||||
from library.lora_util import parse_blocks
|
||||
|
||||
|
||||
def test_single_value():
|
||||
# Test single numeric value
|
||||
result = parse_blocks("1.0")
|
||||
assert len(result) == 19
|
||||
assert all(val == 1.0 for val in result), "set all values to 1.0 when default value is 1.0"
|
||||
|
||||
# Test zero
|
||||
result = parse_blocks("0")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result), "set all values to 0.0 when default value is 0"
|
||||
|
||||
# Test negative value
|
||||
result = parse_blocks("-0.5")
|
||||
assert len(result) == 19
|
||||
assert all(val == -0.5 for val in result), "set all values to -0.5 when default value is -0.5"
|
||||
|
||||
|
||||
def test_explicit_list():
|
||||
# Test exact length list
|
||||
result = parse_blocks("[0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1]")
|
||||
assert len(result) == 19
|
||||
assert result[0] == 0.1
|
||||
assert result[9] == 1.0
|
||||
assert result[18] == 0.1
|
||||
|
||||
# Test shorter list that repeats
|
||||
result = parse_blocks("[0.0,0.5,1.0]")
|
||||
assert len(result) == 19
|
||||
assert result[0] == 0.0
|
||||
assert result[1] == 0.5
|
||||
assert result[2] == 1.0
|
||||
assert result[3] == 0.0 # Pattern repeats
|
||||
assert result[4] == 0.5
|
||||
|
||||
# Test longer list that gets truncated
|
||||
result = parse_blocks("[" + ",".join(["0.5"] * 25) + "]")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.5 for val in result)
|
||||
|
||||
|
||||
def test_default_with_overrides():
|
||||
# Test default value with single index override
|
||||
result = parse_blocks("1.0,0:0.5")
|
||||
assert len(result) == 19
|
||||
assert result[0] == 0.5
|
||||
assert all(val == 1.0 for val in result[1:])
|
||||
|
||||
# Test default with multiple index overrides
|
||||
result = parse_blocks("0.5,1:0.7,5:0.9,10:0.3")
|
||||
assert len(result) == 19
|
||||
assert result[0] == 0.5 # Default value
|
||||
assert result[1] == 0.7 # Override
|
||||
assert result[5] == 0.9 # Override
|
||||
assert result[10] == 0.3 # Override
|
||||
assert result[18] == 0.5 # Default value
|
||||
|
||||
# Test without default value (should use 0.0)
|
||||
result = parse_blocks("3:0.8")
|
||||
assert len(result) == 19
|
||||
assert result[3] == 0.8
|
||||
assert all(val == 0.0 for i, val in enumerate(result) if i != 3)
|
||||
|
||||
|
||||
def test_range_overrides():
|
||||
# Test simple range
|
||||
result = parse_blocks("1-5:0.7")
|
||||
assert len(result) == 19
|
||||
assert all(result[i] == 0.7 for i in range(1, 6))
|
||||
assert all(val == 0.0 for i, val in enumerate(result) if i < 1 or i > 5)
|
||||
|
||||
# Test multiple ranges
|
||||
result = parse_blocks("0.1,1-3:0.5,7-9:0.8")
|
||||
assert len(result) == 19
|
||||
assert all(result[i] == 0.5 for i in range(1, 4))
|
||||
assert all(result[i] == 0.8 for i in range(7, 10))
|
||||
assert result[0] == 0.1 # Default
|
||||
assert result[6] == 0.1 # Default
|
||||
assert result[18] == 0.1 # Default
|
||||
|
||||
|
||||
def test_cos_function():
|
||||
# Test cos over range
|
||||
result = parse_blocks("1-5:cos")
|
||||
assert len(result) == 19
|
||||
# Calculate expected values for cosine function
|
||||
expected_cos = [(1 + math.cos(i / (5 - 1) * math.pi)) / 2 for i in range(5)]
|
||||
for i in range(1, 6):
|
||||
assert result[i] == pytest.approx(expected_cos[i - 1])
|
||||
|
||||
# Test parameterized cos
|
||||
result = parse_blocks("3-7:cos(0.2,0.8)")
|
||||
assert len(result) == 19
|
||||
# Cos goes from 1 to 0 over π, scaled to range 0.2 to 0.8
|
||||
for i in range(5):
|
||||
normalized = (1 + math.cos(i / (5 - 1) * math.pi)) / 2
|
||||
expected = 0.2 + normalized * (0.8 - 0.2)
|
||||
assert result[i + 3] == pytest.approx(expected)
|
||||
|
||||
|
||||
def test_sin_function():
|
||||
# Test sin over range
|
||||
result = parse_blocks("2-6:sin")
|
||||
assert len(result) == 19
|
||||
# Calculate expected values for sine function
|
||||
expected_sin = [math.sin(i / (6 - 2) * (math.pi / 2)) for i in range(5)]
|
||||
for i in range(2, 7):
|
||||
assert result[i] == pytest.approx(expected_sin[i - 2])
|
||||
|
||||
# Test parameterized sin
|
||||
result = parse_blocks("4-8:sin(0.3,0.9)")
|
||||
assert len(result) == 19
|
||||
# Sin goes from 0 to 1 over π/2, scaled to range 0.3 to 0.9
|
||||
for i in range(5):
|
||||
normalized = math.sin(i / (5 - 1) * (math.pi / 2))
|
||||
expected = 0.3 + normalized * (0.9 - 0.3)
|
||||
assert result[i + 4] == pytest.approx(expected)
|
||||
|
||||
|
||||
def test_linear_function():
|
||||
# Test linear over range
|
||||
result = parse_blocks("3-7:linear")
|
||||
assert len(result) == 19
|
||||
# Calculate expected values for linear function (0 to 1)
|
||||
expected_linear = [i / (7 - 3) for i in range(5)]
|
||||
for i in range(3, 8):
|
||||
assert result[i] == pytest.approx(expected_linear[i - 3])
|
||||
|
||||
# Test parameterized linear
|
||||
result = parse_blocks("5-9:linear(0.4,0.7)")
|
||||
assert len(result) == 19
|
||||
# Linear goes from 0.4 to 0.7
|
||||
for i in range(5):
|
||||
t = i / 4 # normalized position
|
||||
expected = 0.4 + t * (0.7 - 0.4)
|
||||
assert result[i + 5] == pytest.approx(expected)
|
||||
|
||||
|
||||
def test_reverse_linear_function():
|
||||
# Test reverse_linear over range
|
||||
result = parse_blocks("2-6:reverse_linear")
|
||||
assert len(result) == 19
|
||||
# Calculate expected values for reverse linear function (1 to 0)
|
||||
expected_reverse = [1 - i / (6 - 2) for i in range(5)]
|
||||
for i in range(2, 7):
|
||||
assert result[i] == pytest.approx(expected_reverse[i - 2])
|
||||
|
||||
# Test parameterized reverse_linear
|
||||
result = parse_blocks("10-15:reverse_linear(0.8,0.2)")
|
||||
assert len(result) == 19
|
||||
# Reverse linear goes from 0.2 to 0.8 (reversed)
|
||||
for i in range(6):
|
||||
t = i / 5 # normalized position
|
||||
expected = 0.2 + t * (0.8 - 0.2)
|
||||
assert result[i + 10] == pytest.approx(expected)
|
||||
|
||||
|
||||
def test_custom_length():
|
||||
# Test with custom length
|
||||
result = parse_blocks("1.0", length=5)
|
||||
assert len(result) == 5
|
||||
assert all(val == 1.0 for val in result)
|
||||
|
||||
# Test list with custom length
|
||||
result = parse_blocks("[0.1,0.2,0.3]", length=10)
|
||||
assert len(result) == 10
|
||||
assert result[0] == 0.1
|
||||
assert result[3] == 0.1 # Pattern repeats
|
||||
|
||||
# Test ranges with custom length
|
||||
result = parse_blocks("1-3:0.5", length=7)
|
||||
assert len(result) == 7
|
||||
assert all(result[i] == 0.5 for i in range(1, 4))
|
||||
assert result[0] == 0.0
|
||||
assert result[6] == 0.0
|
||||
|
||||
|
||||
def test_custom_default():
|
||||
# Test with custom default value
|
||||
result = parse_blocks("1:0.5", default=0.2)
|
||||
assert len(result) == 19
|
||||
assert result[1] == 0.5
|
||||
assert result[0] == 0.2
|
||||
assert result[18] == 0.2
|
||||
|
||||
# Test overriding default value
|
||||
result = parse_blocks("0.7,1:0.5", default=0.2)
|
||||
assert len(result) == 19
|
||||
assert result[1] == 0.5
|
||||
assert result[0] == 0.7 # Explicitly set default
|
||||
assert result[18] == 0.7
|
||||
|
||||
|
||||
def test_out_of_bounds_indices():
|
||||
# Test negative indices (should be ignored)
|
||||
result = parse_blocks("-5:0.9")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result), "Negative index should be ignored"
|
||||
|
||||
# Test indices beyond length
|
||||
result = parse_blocks("25:0.8")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result), "Indices above the max length should be ignored"
|
||||
|
||||
# Test range partially out of bounds
|
||||
result = parse_blocks("17-22:0.7")
|
||||
assert len(result) == 19
|
||||
assert result[17] == 0.7
|
||||
assert result[18] == 0.7
|
||||
# Indices 19-22 would be out of bounds
|
||||
|
||||
|
||||
def test_mixed_patterns():
|
||||
# Test combining different formats
|
||||
result = parse_blocks("0.3,2:0.8,5-8:cos,10-15:linear(0.1,0.9)")
|
||||
assert len(result) == 19
|
||||
assert result[0] == 0.3 # Default
|
||||
assert result[2] == 0.8 # Single index
|
||||
|
||||
# Check cos values
|
||||
cos_range = range(5, 9)
|
||||
expected_cos = [(1 + math.cos(i / (8 - 5) * math.pi)) / 2 for i in range(4)]
|
||||
for i, idx in enumerate(cos_range):
|
||||
assert result[idx] == pytest.approx(expected_cos[i])
|
||||
|
||||
# Check linear values
|
||||
linear_range = range(10, 16)
|
||||
for i, idx in enumerate(linear_range):
|
||||
t = i / 5 # normalized position
|
||||
expected = 0.1 + t * (0.9 - 0.1)
|
||||
assert result[idx] == pytest.approx(expected)
|
||||
|
||||
|
||||
def test_edge_cases():
|
||||
# Test empty string
|
||||
result = parse_blocks("")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result)
|
||||
|
||||
# Test whitespace
|
||||
result = parse_blocks(" ")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result)
|
||||
|
||||
# Test empty list
|
||||
result = parse_blocks("[]")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result)
|
||||
|
||||
# Test single-item range
|
||||
result = parse_blocks("5-5:0.7")
|
||||
assert len(result) == 19
|
||||
assert result[5] == 0.7
|
||||
assert result[4] == 0.0
|
||||
assert result[6] == 0.0
|
||||
|
||||
# Test function with single-item range
|
||||
result = parse_blocks("7-7:cos")
|
||||
assert len(result) == 19
|
||||
assert result[7] == 1.0 # When range is single point, cos at position 0 is 1
|
||||
|
||||
# Test overlapping ranges
|
||||
result = parse_blocks("1-5:0.3,3-7:0.8")
|
||||
assert len(result) == 19
|
||||
assert result[1] == 0.3
|
||||
assert result[2] == 0.3
|
||||
assert result[3] == 0.8 # Later definition overwrites
|
||||
assert result[4] == 0.8 # Later definition overwrites
|
||||
assert result[5] == 0.8 # Later definition overwrites
|
||||
assert result[7] == 0.8
|
||||
assert result[8] == 0.0
|
||||
|
||||
|
||||
def test_malformed_input():
|
||||
# Test malformed list
|
||||
result = parse_blocks("[0.1,0.2,")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result), "malformed list"
|
||||
|
||||
# Test invalid end range
|
||||
result = parse_blocks("5-:0.7")
|
||||
assert len(result) == 19
|
||||
assert result[5] == 0.7
|
||||
assert result[6] == 0.0
|
||||
|
||||
# Test invalid start range, indices should never be negative
|
||||
result = parse_blocks("-5:0.7")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result), "invalid start range, indices should never be negative"
|
||||
|
||||
# Test invalid function
|
||||
result = parse_blocks("1-5:unknown_func")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result), "Function name not recognized"
|
||||
|
||||
# Test invalid function parameters
|
||||
result = parse_blocks("1-5:cos(invalid,0.8)")
|
||||
assert len(result) == 19
|
||||
assert all(val == 0.0 for val in result), "Invalid parameters"
|
||||
Reference in New Issue
Block a user