# Kohya-ss-sd-scripts/tests/library/test_strategy_sd_tokenize.py
# 2025-10-10 15:07:31 -04:00
# 379 lines, 16 KiB, Python

import pytest
import torch
from unittest.mock import Mock, patch
from library.strategy_sd import SdTokenizeStrategy
class TestSdTokenizeStrategy:
    """Tests for ``SdTokenizeStrategy``: BREAK splitting, tokenization and weighted tokenization."""

    @pytest.fixture
    def mock_tokenizer(self):
        """Return a callable ``Mock`` standing in for a CLIP tokenizer.

        Calling the mock yields an object whose ``input_ids`` is
        ``[BOS, 1..n, EOS]`` where ``n`` is the number of whitespace-separated
        words in the text (capped at 75). When ``return_tensors="pt"`` is
        passed, the sequence is padded with the pad token up to ``max_length``
        (default 77) and given a leading batch dimension.
        """
        tokenizer = Mock(
            model_max_length=77,
            bos_token_id=49406,
            eos_token_id=49407,
            pad_token_id=49407,
        )

        def fake_encode(text, **kwargs):
            # Deliberately simplistic: one ID per word, no subword splitting.
            word_count = min(len(text.split()), 75)
            # Special-token IDs are read at call time so later changes to the
            # mock's attributes (e.g. a different pad token) take effect.
            bos = torch.tensor([tokenizer.bos_token_id])
            body = torch.arange(1, word_count + 1)
            eos = torch.tensor([tokenizer.eos_token_id])

            if kwargs.get("return_tensors") != "pt":
                return Mock(input_ids=torch.cat([bos, body, eos]))

            target_len = kwargs.get("max_length", 77)
            padding = torch.full((target_len - word_count - 2,), tokenizer.pad_token_id)
            return Mock(input_ids=torch.cat([bos, body, eos, padding]).unsqueeze(0))

        tokenizer.side_effect = fake_encode
        return tokenizer

    @pytest.fixture
    def strategy_v1(self, mock_tokenizer):
        """Return a ``v2=False`` strategy whose tokenizer loader is patched to the mock."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            return SdTokenizeStrategy(v2=False, max_length=75, tokenizer_cache_dir=None)

    @pytest.fixture
    def strategy_v2(self, mock_tokenizer):
        """Return a ``v2=True`` strategy whose tokenizer loader is patched to the mock."""
        # The v2 setup uses a pad token ID of 0 instead of reusing the EOS ID.
        mock_tokenizer.pad_token_id = 0
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            return SdTokenizeStrategy(v2=True, max_length=75, tokenizer_cache_dir=None)

    # ------------------------------------------------------------------
    # _split_on_break
    # ------------------------------------------------------------------

    def test_split_on_break_no_break(self, strategy_v1):
        """A prompt without BREAK comes back as a single, unchanged segment."""
        segments = strategy_v1._split_on_break("a cat and a dog")

        assert len(segments) == 1
        assert segments[0] == "a cat and a dog"

    def test_split_on_break_single_break(self, strategy_v1):
        """One BREAK yields two segments."""
        segments = strategy_v1._split_on_break("a cat BREAK a dog")

        assert len(segments) == 2
        assert segments[0] == "a cat"
        assert segments[1] == "a dog"

    def test_split_on_break_multiple_breaks(self, strategy_v1):
        """Two BREAKs yield three segments, in prompt order."""
        segments = strategy_v1._split_on_break("a cat BREAK a dog BREAK a bird")

        assert len(segments) == 3
        assert segments[0] == "a cat"
        assert segments[1] == "a dog"
        assert segments[2] == "a bird"

    def test_split_on_break_case_sensitive(self, strategy_v1):
        """Only upper-case BREAK acts as a separator."""
        # All lower case: must be left alone.
        lowercase = strategy_v1._split_on_break("a cat break a dog")
        assert len(lowercase) == 1
        assert lowercase[0] == "a cat break a dog"

        # Mixed case: must not split either.
        mixed_case = strategy_v1._split_on_break("a cat Break a dog")
        assert len(mixed_case) == 1

    def test_split_on_break_with_whitespace(self, strategy_v1):
        """Segments around BREAK come back without surrounding whitespace."""
        # NOTE(review): this test is meant to cover *extra* whitespace around
        # BREAK, but the literal here has single spaces only -- confirm that
        # additional spaces were not lost from the prompt.
        segments = strategy_v1._split_on_break("a cat BREAK a dog")

        assert len(segments) == 2
        assert segments[0] == "a cat"
        assert segments[1] == "a dog"

    def test_split_on_break_empty_segments(self, strategy_v1):
        """Leading, trailing and doubled BREAKs do not produce empty segments."""
        segments = strategy_v1._split_on_break("BREAK a cat BREAK BREAK a dog BREAK")

        assert len(segments) == 2
        assert segments[0] == "a cat"
        assert segments[1] == "a dog"

    def test_split_on_break_only_break(self, strategy_v1):
        """A prompt consisting solely of BREAK gives one empty-string segment."""
        segments = strategy_v1._split_on_break("BREAK")

        assert len(segments) == 1
        assert segments[0] == ""

    def test_split_on_break_empty_string(self, strategy_v1):
        """An empty prompt gives one empty-string segment."""
        segments = strategy_v1._split_on_break("")

        assert len(segments) == 1
        assert segments[0] == ""

    # ------------------------------------------------------------------
    # tokenize -- prompts without BREAK
    # ------------------------------------------------------------------

    def test_tokenize_single_text_no_break(self, strategy_v1):
        """A single string tokenizes to a one-element result holding a 3-D tensor."""
        encoded = strategy_v1.tokenize("a cat")

        assert len(encoded) == 1
        ids = encoded[0]
        assert isinstance(ids, torch.Tensor)
        assert ids.dim() == 3  # [batch, n_chunks, seq_len]

    def test_tokenize_list_no_break(self, strategy_v1):
        """A list of two prompts produces a batch dimension of two."""
        encoded = strategy_v1.tokenize(["a cat", "a dog"])

        assert len(encoded) == 1
        assert encoded[0].shape[0] == 2  # batch size

    # ------------------------------------------------------------------
    # tokenize -- prompts with BREAK
    # ------------------------------------------------------------------

    def test_tokenize_single_break(self, strategy_v1):
        """A prompt with one BREAK still tokenizes to a single tensor."""
        encoded = strategy_v1.tokenize("a cat BREAK a dog")

        assert len(encoded) == 1
        assert isinstance(encoded[0], torch.Tensor)
        # Intent: the tensor carries the tokens of both segments (not asserted here).

    def test_tokenize_multiple_breaks(self, strategy_v1):
        """A prompt with several BREAKs still tokenizes to a single tensor."""
        encoded = strategy_v1.tokenize("a cat BREAK a dog BREAK a bird")

        assert len(encoded) == 1
        assert isinstance(encoded[0], torch.Tensor)

    def test_tokenize_list_with_breaks(self, strategy_v1):
        """Batching works when only some of the prompts contain BREAK."""
        encoded = strategy_v1.tokenize(["a cat BREAK a dog", "a bird"])

        assert len(encoded) == 1
        assert encoded[0].shape[0] == 2  # batch size

    # ------------------------------------------------------------------
    # tokenize_with_weights -- prompts without BREAK
    # ------------------------------------------------------------------

    def test_tokenize_with_weights_no_break(self, strategy_v1):
        """Tokens and weights come back as shape-matched tensors."""
        token_batches, weight_batches = strategy_v1.tokenize_with_weights("a cat")

        assert len(token_batches) == 1
        assert len(weight_batches) == 1
        assert isinstance(token_batches[0], torch.Tensor)
        assert isinstance(weight_batches[0], torch.Tensor)
        assert token_batches[0].shape == weight_batches[0].shape

    def test_tokenize_with_weights_list_no_break(self, strategy_v1):
        """A list of two prompts gives a batch of two, with weights shaped like tokens."""
        token_batches, weight_batches = strategy_v1.tokenize_with_weights(["a cat", "a dog"])

        assert len(token_batches) == 1
        assert len(weight_batches) == 1
        assert token_batches[0].shape[0] == 2  # batch size
        assert token_batches[0].shape == weight_batches[0].shape

    # ------------------------------------------------------------------
    # tokenize_with_weights -- prompts with BREAK
    # ------------------------------------------------------------------

    def test_tokenize_with_weights_single_break(self, strategy_v1):
        """One BREAK: tokens and weights remain shape-matched tensors."""
        token_batches, weight_batches = strategy_v1.tokenize_with_weights("a cat BREAK a dog")

        assert len(token_batches) == 1
        assert len(weight_batches) == 1
        assert isinstance(token_batches[0], torch.Tensor)
        assert isinstance(weight_batches[0], torch.Tensor)
        assert token_batches[0].shape == weight_batches[0].shape

    def test_tokenize_with_weights_multiple_breaks(self, strategy_v1):
        """Several BREAKs: tokens and weights remain shape-matched."""
        token_batches, weight_batches = strategy_v1.tokenize_with_weights("a cat BREAK a dog BREAK a bird")

        assert len(token_batches) == 1
        assert len(weight_batches) == 1
        assert token_batches[0].shape == weight_batches[0].shape

    def test_tokenize_with_weights_list_with_breaks(self, strategy_v1):
        """A batch of BREAK prompts keeps batch size two and matching shapes."""
        prompts = ["a cat BREAK a dog", "a bird BREAK a fish"]
        token_batches, weight_batches = strategy_v1.tokenize_with_weights(prompts)

        assert len(token_batches) == 1
        assert len(weight_batches) == 1
        assert token_batches[0].shape[0] == 2  # batch size
        assert token_batches[0].shape == weight_batches[0].shape

    # ------------------------------------------------------------------
    # Attention-weight syntax
    # ------------------------------------------------------------------

    def test_tokenize_with_weights_attention_syntax(self, strategy_v1):
        """Prompts using ``(word:1.5)`` emphasis are accepted."""
        token_batches, weight_batches = strategy_v1.tokenize_with_weights("a (cat:1.5) and a dog")

        assert len(token_batches) == 1
        assert len(weight_batches) == 1
        # Intent: the emphasized word gets a weight other than 1.0 (not asserted here).

    def test_tokenize_with_weights_attention_and_break(self, strategy_v1):
        """Emphasis syntax and BREAK can be combined in one prompt."""
        token_batches, weight_batches = strategy_v1.tokenize_with_weights("a (cat:1.5) BREAK a [dog:0.8]")

        assert len(token_batches) == 1
        assert len(weight_batches) == 1
        assert token_batches[0].shape == weight_batches[0].shape

    def test_break_splits_long_prompts_into_chunks(self, strategy_v1):
        """Two 40-word segments joined by BREAK give a long concatenated sequence."""
        # The mock tokenizer emits exactly one ID per whitespace-separated word,
        # so each segment is 40 IDs plus BOS/EOS.
        # NOTE(review): the test was written expecting ~80 tokens per segment;
        # with this mock the >=150 bound presumably relies on each segment being
        # padded to the model length before concatenation -- confirm against
        # SdTokenizeStrategy.
        half = " ".join([f"word{i}" for i in range(40)])
        prompt = f"{half} BREAK {half}"

        token_batches, weight_batches = strategy_v1.tokenize_with_weights(prompt)

        assert len(token_batches) == 1
        assert len(weight_batches) == 1

        # The last dimension is the token axis checked by both bounds below.
        seq_len = token_batches[0].shape[-1]
        # Longer than one 77-token chunk ...
        assert seq_len > 77, f"Expected >77 tokens but got {seq_len}"
        # ... and roughly the length of two chunks.
        assert seq_len >= 150, f"Expected >=150 tokens for 2 long segments but got {seq_len}"

    def test_break_splits_result_in_proper_chunks(self, strategy_v1):
        """Two 20-word segments produce shape-matched tokens/weights longer than 40."""
        first_half = " ".join([f"word{i}" for i in range(20)])
        second_half = " ".join([f"word{i}" for i in range(20, 40)])

        token_batches, weight_batches = strategy_v1.tokenize_with_weights(f"{first_half} BREAK {second_half}")
        ids = token_batches[0]
        strengths = weight_batches[0]

        # Both segments must contribute to the concatenated output.
        assert ids.shape == strengths.shape
        assert ids.shape[-1] > 40, "Should have tokens from both segments"

    # ------------------------------------------------------------------
    # v1 vs v2
    # ------------------------------------------------------------------

    def test_v1_vs_v2_initialization(self, mock_tokenizer):
        """Both model variants construct with a tokenizer and the same max length."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            v1 = SdTokenizeStrategy(v2=False, max_length=75)
            v2 = SdTokenizeStrategy(v2=True, max_length=75)

            assert v1.tokenizer is not None
            assert v2.tokenizer is not None
            assert v1.max_length == 77  # 75 + 2 for BOS/EOS
            assert v2.max_length == 77

    # ------------------------------------------------------------------
    # max_length handling
    # ------------------------------------------------------------------

    def test_max_length_none(self, mock_tokenizer):
        """``max_length=None`` falls back to the tokenizer's ``model_max_length``."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=None)

            assert strategy.max_length == mock_tokenizer.model_max_length

    def test_max_length_custom(self, mock_tokenizer):
        """A custom ``max_length`` is stored with two extra positions."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=150)

            assert strategy.max_length == 152  # 150 + 2 for BOS/EOS
class TestEdgeCases:
    """Edge-case prompts for ``SdTokenizeStrategy.tokenize``."""

    @pytest.fixture
    def mock_tokenizer(self):
        """Return a callable ``Mock`` standing in for a CLIP tokenizer.

        Calling the mock yields an object whose ``input_ids`` is
        ``[BOS, 1..n, EOS]`` with ``n`` the number of whitespace-separated
        words (capped at 75). With ``return_tensors="pt"`` the sequence is
        padded to ``max_length`` (default 77) and gets a batch dimension.
        """
        tokenizer = Mock(
            model_max_length=77,
            bos_token_id=49406,
            eos_token_id=49407,
            pad_token_id=49407,
        )

        def fake_encode(text, **kwargs):
            word_count = min(len(text.split()), 75)
            # Special-token IDs are looked up on the mock at call time.
            bos = torch.tensor([tokenizer.bos_token_id])
            body = torch.arange(1, word_count + 1)
            eos = torch.tensor([tokenizer.eos_token_id])

            if kwargs.get("return_tensors") != "pt":
                return Mock(input_ids=torch.cat([bos, body, eos]))

            target_len = kwargs.get("max_length", 77)
            padding = torch.full((target_len - word_count - 2,), tokenizer.pad_token_id)
            return Mock(input_ids=torch.cat([bos, body, eos, padding]).unsqueeze(0))

        tokenizer.side_effect = fake_encode
        return tokenizer

    def test_very_long_text_with_breaks(self, mock_tokenizer):
        """Three 50-word segments separated by BREAK tokenize to one tensor."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=75)

            segment = " ".join([f"word{i}" for i in range(50)])
            encoded = strategy.tokenize(f"{segment} BREAK {segment} BREAK {segment}")

            assert len(encoded) == 1
            assert isinstance(encoded[0], torch.Tensor)

    def test_break_at_boundaries(self, mock_tokenizer):
        """BREAK at the start, the end, or both ends of a prompt is tolerated."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=75)

            # Leading BREAK
            leading = strategy.tokenize("BREAK a cat")
            assert len(leading) == 1

            # Trailing BREAK
            trailing = strategy.tokenize("a cat BREAK")
            assert len(trailing) == 1

            # BREAK on both sides
            wrapped = strategy.tokenize("BREAK a cat BREAK")
            assert len(wrapped) == 1

    def test_consecutive_breaks(self, mock_tokenizer):
        """A run of adjacent BREAKs between two segments is tolerated."""
        with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
            strategy = SdTokenizeStrategy(v2=False, max_length=75)

            encoded = strategy.tokenize("a cat BREAK BREAK BREAK a dog")

            assert len(encoded) == 1
            # Intent: the empty segments between adjacent BREAKs are dropped,
            # leaving two segments (the segment count is not asserted here).
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; hand it to the interpreter so
    # that running this file as a script exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))