# Mirror of https://github.com/kohya-ss/sd-scripts.git
# Synced 2026-04-10 15:00:23 +00:00
from unittest.mock import Mock

import pytest
import torch

from library.strategy_sd import SdTextEncodingStrategy
|
class TestSdTextEncodingStrategy:
    """Tests for SdTextEncodingStrategy: CLIP-skip selection and
    prompt-weight application, using deterministic mock encoders."""

    @pytest.fixture
    def strategy(self):
        """Strategy with CLIP skip disabled."""
        return SdTextEncodingStrategy(clip_skip=None)

    @pytest.fixture
    def strategy_with_clip_skip(self):
        """Strategy configured to skip the last CLIP layer."""
        return SdTextEncodingStrategy(clip_skip=2)

    @pytest.fixture
    def mock_tokenizer(self):
        """Tokenizer stub exposing the attributes the strategy reads."""
        tok = Mock()
        tok.model_max_length = 77
        tok.pad_token_id = 0
        tok.eos_token = 2
        tok.eos_token_id = 2
        return tok

    @pytest.fixture
    def mock_text_encoder(self):
        """Text-encoder stub that returns deterministic hidden states."""
        encoder = Mock()
        encoder.device = torch.device("cpu")

        def fake_encode(tokens, output_hidden_states=False, return_dict=False):
            batch, seq = tokens.shape[0], tokens.shape[1]
            # Constant 0.5 everywhere keeps every downstream assertion exact.
            base = torch.ones(batch, seq, 768) * 0.5
            if return_dict:
                # Three "layers" scaled differently so clip_skip selection
                # is observable in the output values.
                return {
                    "hidden_states": [
                        base * 0.8,
                        base * 0.9,
                        base * 1.0,
                    ]
                }
            return [base]

        encoder.side_effect = fake_encode
        encoder.text_model = Mock()
        # Identity layer norm: values pass through unchanged after
        # clip-skip selection.
        encoder.text_model.final_layer_norm = lambda x: x
        return encoder

    @pytest.fixture
    def mock_tokenize_strategy(self, mock_tokenizer):
        """Tokenize-strategy stub carrying only the mock tokenizer."""
        tok_strategy = Mock()
        tok_strategy.tokenizer = mock_tokenizer
        return tok_strategy

    # --- _encode_with_clip_skip ---

    def test_encode_without_clip_skip(self, strategy, mock_text_encoder):
        """With clip_skip=None the final hidden state is used as-is."""
        token_ids = torch.arange(154).reshape(2, 77)
        out = strategy._encode_with_clip_skip(mock_text_encoder, token_ids)
        assert out.shape == (2, 77, 768)
        # Deterministic value straight from the stub encoder.
        assert torch.allclose(out[0, 0, 0], torch.tensor(0.5))

    def test_encode_with_clip_skip(self, strategy_with_clip_skip, mock_text_encoder):
        """With clip_skip=2 the second-to-last hidden state is selected."""
        token_ids = torch.arange(154).reshape(2, 77)
        out = strategy_with_clip_skip._encode_with_clip_skip(mock_text_encoder, token_ids)
        assert out.shape == (2, 77, 768)
        # Second-to-last stub layer: 0.5 * 0.9 = 0.45.
        assert torch.allclose(out[0, 0, 0], torch.tensor(0.45))

    # --- _apply_weights_single_chunk ---

    def test_apply_weights_single_chunk(self, strategy):
        """Weights scale the hidden states elementwise for one chunk."""
        states = torch.ones(2, 77, 768)
        weights = torch.full((2, 1, 77), 0.5)
        out = strategy._apply_weights_single_chunk(states, weights)
        assert out.shape == (2, 77, 768)
        # 1.0 * 0.5 = 0.5.
        assert torch.allclose(out[0, 0, 0], torch.tensor(0.5))

    # --- _apply_weights_multi_chunk ---

    def test_apply_weights_multi_chunk(self, strategy):
        """Weights are applied per chunk for a two-chunk prompt."""
        # Two chunks of 75 content tokens plus BOS/EOS: 2*75 + 2 = 152.
        states = torch.ones(2, 152, 768)
        weights = torch.full((2, 2, 77), 0.5)
        out = strategy._apply_weights_multi_chunk(states, weights)
        assert out.shape == (2, 152, 768)
        # Middle (content) positions carry the 0.5 weight.
        assert torch.allclose(out[0, 1, 0], torch.tensor(0.5))
        assert torch.allclose(out[0, 76, 0], torch.tensor(0.5))

    # --- integration ---

    def test_encode_tokens_basic(self, strategy, mock_tokenize_strategy, mock_text_encoder):
        """encode_tokens runs end to end and yields deterministic states."""
        token_ids = torch.arange(154).reshape(2, 1, 77)

        outputs = strategy.encode_tokens(mock_tokenize_strategy, [mock_text_encoder], [token_ids])

        assert len(outputs) == 1
        assert outputs[0].shape[0] == 2  # batch size
        assert outputs[0].shape[2] == 768  # hidden size
        # Deterministic value straight from the stub encoder.
        assert torch.allclose(outputs[0][0, 0, 0], torch.tensor(0.5))

    def test_encode_tokens_with_weights_single_chunk(self, strategy, mock_tokenize_strategy, mock_text_encoder):
        """encode_tokens_with_weights scales encoder output by the weights."""
        token_ids = torch.arange(154).reshape(2, 1, 77)
        weights = torch.full((2, 1, 77), 0.5)

        outputs = strategy.encode_tokens_with_weights(
            mock_tokenize_strategy, [mock_text_encoder], [token_ids], [weights]
        )

        assert len(outputs) == 1
        assert outputs[0].shape[0] == 2
        assert outputs[0].shape[2] == 768
        # 0.5 (encoder output) * 0.5 (weight) = 0.25.
        assert torch.allclose(outputs[0][0, 0, 0], torch.tensor(0.25))
|
|
|
# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])