diff --git a/library/strategy_sd.py b/library/strategy_sd.py index 4521ae8d..3ce9f7dc 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -31,81 +31,171 @@ class SdTokenizeStrategy(TokenizeStrategy): ) else: self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir) - if max_length is None: self.max_length = self.tokenizer.model_max_length else: self.max_length = max_length + 2 - + + self.break_separator = "BREAK" + + def _split_on_break(self, text: str) -> List[str]: + """Split text on BREAK separator (case-sensitive), filtering empty segments.""" + segments = text.split(self.break_separator) + # Filter out empty or whitespace-only segments + filtered = [seg.strip() for seg in segments if seg.strip()] + # Return at least one segment to maintain consistency + return filtered if filtered else [""] + + def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Tokenize multiple segments and concatenate them.""" + if len(segments) == 1: + # No BREAK present, use existing logic + if weighted: + return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True) + else: + tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length) + return tokens, None + + # Multiple segments - tokenize each separately + all_tokens = [] + all_weights = [] if weighted else None + + for segment in segments: + if weighted: + seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True) + all_tokens.append(seg_tokens) + all_weights.append(seg_weights) + else: + seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length) + all_tokens.append(seg_tokens) + + # Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len]) + combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0) + combined_weights = None + if weighted: + combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0) + + return combined_tokens, combined_weights + def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]: text = [text] if isinstance(text, str) else text - return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)] - + + tokens_list = [] + for t in text: + segments = self._split_on_break(t) + tokens, _ = self._tokenize_segments(segments, weighted=False) + tokens_list.append(tokens) + + # Pad tokens to same length for stacking + max_length = max(t.shape[-1] for t in tokens_list) + padded_tokens = [] + for tokens in tokens_list: + if tokens.shape[-1] < max_length: + # Pad with pad_token_id + pad_size = max_length - tokens.shape[-1] + if tokens.dim() == 2: + padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype) + tokens = torch.cat([tokens, padding], dim=1) + else: + padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype) + tokens = torch.cat([tokens, padding], dim=0) + padded_tokens.append(tokens) + + return [torch.stack(padded_tokens, dim=0)] + def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: text = [text] if isinstance(text, str) else text + tokens_list = [] weights_list = [] for t in text: - tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True) + segments = self._split_on_break(t) + tokens, weights = self._tokenize_segments(segments, weighted=True) tokens_list.append(tokens) weights_list.append(weights) + return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)] class SdTextEncodingStrategy(TextEncodingStrategy): def __init__(self, clip_skip: Optional[int] = None) -> None: self.clip_skip = clip_skip - + + def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor: + """Encode tokens with optional CLIP skip.""" + if self.clip_skip is None: + return text_encoder(tokens)[0] + + enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) + hidden_states = enc_out["hidden_states"][-self.clip_skip] + return text_encoder.text_model.final_layer_norm(hidden_states) + + def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor, + max_token_length: int, model_max_length: int, + tokenizer: Any) -> torch.Tensor: + """Reconstruct embeddings from chunked encoding.""" + v1 = tokenizer.pad_token_id == tokenizer.eos_token_id + states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # + + if not v1: + # v2: ... ... の三連を ... ... へ戻す + for i in range(1, max_token_length, model_max_length): + chunk = encoder_hidden_states[:, i : i + model_max_length - 2] + if i > 0: + for j in range(len(chunk)): + if tokens[j, 1] == tokenizer.eos_token: + chunk[j, 0] = chunk[j, 1] + states_list.append(chunk) + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) + else: + # v1: ... の三連を ... へ戻す + for i in range(1, max_token_length, model_max_length): + states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) + states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) + + return torch.cat(states_list, dim=1) + + def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor, + weights: torch.Tensor) -> torch.Tensor: + """Apply weights for single chunk case (no max_token_length).""" + return encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + + def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor, + weights: torch.Tensor) -> torch.Tensor: + """Apply weights for multi-chunk case (with max_token_length).""" + for i in range(weights.shape[1]): + start_idx = i * 75 + 1 + end_idx = i * 75 + 76 + encoder_hidden_states[:, start_idx:end_idx] = ( + encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1) + ) + return encoder_hidden_states + def encode_tokens( self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor] ) -> List[torch.Tensor]: text_encoder = models[0] tokens = tokens[0] sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy - - # tokens: b,n,77 + b_size = tokens.size()[0] max_token_length = tokens.size()[1] * tokens.size()[2] model_max_length = sd_tokenize_strategy.tokenizer.model_max_length - tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77 - + + tokens = tokens.reshape((-1, model_max_length)) tokens = tokens.to(text_encoder.device) - - if self.clip_skip is None: - encoder_hidden_states = text_encoder(tokens)[0] - else: - enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True) - encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip] - encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states) - - # bs*3, 77, 768 or 1024 + + encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens) encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1])) - + if max_token_length != model_max_length: - v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id - if not v1: - # v2: ... ... の三連を ... ... へ戻す 正直この実装でいいのかわからん - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, max_token_length, model_max_length): - chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # の後から 最後の前まで - if i > 0: - for j in range(len(chunk)): - if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token: - # 空、つまり ...のパターン - chunk[j, 0] = chunk[j, 1] # 次の の値をコピーする - states_list.append(chunk) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # のどちらか - encoder_hidden_states = torch.cat(states_list, dim=1) - else: - # v1: ... の三連を ... へ戻す - states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # - for i in range(1, max_token_length, model_max_length): - states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # の後から の前まで - states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # - encoder_hidden_states = torch.cat(states_list, dim=1) - + encoder_hidden_states = self._reconstruct_embeddings( + encoder_hidden_states, tokens, max_token_length, + model_max_length, sd_tokenize_strategy.tokenizer + ) + return [encoder_hidden_states] - + def encode_tokens_with_weights( self, tokenize_strategy: TokenizeStrategy, @@ -114,23 +204,15 @@ class SdTextEncodingStrategy(TextEncodingStrategy): weights_list: List[torch.Tensor], ) -> List[torch.Tensor]: encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0] - weights = weights_list[0].to(encoder_hidden_states.device) - - # apply weights - if weights.shape[1] == 1: # no max_token_length - # weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768) - encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2) + + if weights.shape[1] == 1: + encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights) else: - # weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768) - for i in range(weights.shape[1]): - encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[ - :, i, 1:-1 - ].unsqueeze(-1) - + encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights) + return [encoder_hidden_states] - class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): # sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix. # and we keep the old npz for the backward compatibility. diff --git a/tests/library/test_strategy_sd_text_encoding.py b/tests/library/test_strategy_sd_text_encoding.py new file mode 100644 index 00000000..a5fedb4c --- /dev/null +++ b/tests/library/test_strategy_sd_text_encoding.py @@ -0,0 +1,140 @@ +import pytest +import torch +from unittest.mock import Mock + +from library.strategy_sd import SdTextEncodingStrategy + + +class TestSdTextEncodingStrategy: + @pytest.fixture + def strategy(self): + """Create strategy instance with default settings.""" + return SdTextEncodingStrategy(clip_skip=None) + + @pytest.fixture + def strategy_with_clip_skip(self): + """Create strategy instance with CLIP skip enabled.""" + return SdTextEncodingStrategy(clip_skip=2) + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock tokenizer.""" + tokenizer = Mock() + tokenizer.model_max_length = 77 + tokenizer.pad_token_id = 0 + tokenizer.eos_token = 2 + tokenizer.eos_token_id = 2 + return tokenizer + + @pytest.fixture + def mock_text_encoder(self): + """Create a mock text encoder.""" + encoder = Mock() + encoder.device = torch.device("cpu") + + def encode_side_effect(tokens, output_hidden_states=False, return_dict=False): + batch_size = tokens.shape[0] + seq_len = tokens.shape[1] + hidden_size = 768 + + # Create deterministic hidden states + hidden_state = torch.ones(batch_size, seq_len, hidden_size) * 0.5 + + if return_dict: + result = { + "hidden_states": [ + hidden_state * 0.8, + hidden_state * 0.9, + hidden_state * 1.0, + ] + } + return result + else: + return [hidden_state] + + encoder.side_effect = encode_side_effect + encoder.text_model = Mock() + encoder.text_model.final_layer_norm = lambda x: x + + return encoder + + @pytest.fixture + def mock_tokenize_strategy(self, mock_tokenizer): + """Create a mock tokenize strategy.""" + strategy = Mock() + strategy.tokenizer = mock_tokenizer + return strategy + + # Test _encode_with_clip_skip + def test_encode_without_clip_skip(self, strategy, mock_text_encoder): + """Test encoding without CLIP skip.""" + tokens = torch.arange(154).reshape(2, 77) + result = strategy._encode_with_clip_skip(mock_text_encoder, tokens) + assert result.shape == (2, 77, 768) + # Verify deterministic output + assert torch.allclose(result[0, 0, 0], torch.tensor(0.5)) + + def test_encode_with_clip_skip(self, strategy_with_clip_skip, mock_text_encoder): + """Test encoding with CLIP skip.""" + tokens = torch.arange(154).reshape(2, 77) + result = strategy_with_clip_skip._encode_with_clip_skip(mock_text_encoder, tokens) + assert result.shape == (2, 77, 768) + # With clip_skip=2, should use second-to-last hidden state (0.5 * 0.9 = 0.45) + assert torch.allclose(result[0, 0, 0], torch.tensor(0.45)) + + # Test _apply_weights_single_chunk + def test_apply_weights_single_chunk(self, strategy): + """Test applying weights for single chunk case.""" + encoder_hidden_states = torch.ones(2, 77, 768) + weights = torch.ones(2, 1, 77) * 0.5 + result = strategy._apply_weights_single_chunk(encoder_hidden_states, weights) + assert result.shape == (2, 77, 768) + # Verify weights were applied: 1.0 * 0.5 = 0.5 + assert torch.allclose(result[0, 0, 0], torch.tensor(0.5)) + + # Test _apply_weights_multi_chunk + def test_apply_weights_multi_chunk(self, strategy): + """Test applying weights for multi-chunk case.""" + # Simulating 2 chunks: 2*75+2 = 152 tokens + encoder_hidden_states = torch.ones(2, 152, 768) + weights = torch.ones(2, 2, 77) * 0.5 + result = strategy._apply_weights_multi_chunk(encoder_hidden_states, weights) + assert result.shape == (2, 152, 768) + # Check that weights were applied to middle sections + assert torch.allclose(result[0, 1, 0], torch.tensor(0.5)) + assert torch.allclose(result[0, 76, 0], torch.tensor(0.5)) + + # Integration tests + def test_encode_tokens_basic(self, strategy, mock_tokenize_strategy, mock_text_encoder): + """Test basic token encoding flow.""" + tokens = torch.arange(154).reshape(2, 1, 77) + models = [mock_text_encoder] + tokens_list = [tokens] + + result = strategy.encode_tokens(mock_tokenize_strategy, models, tokens_list) + + assert len(result) == 1 + assert result[0].shape[0] == 2 # batch size + assert result[0].shape[2] == 768 # hidden size + # Verify deterministic output + assert torch.allclose(result[0][0, 0, 0], torch.tensor(0.5)) + + def test_encode_tokens_with_weights_single_chunk(self, strategy, mock_tokenize_strategy, mock_text_encoder): + """Test weighted encoding with single chunk.""" + tokens = torch.arange(154).reshape(2, 1, 77) + weights = torch.ones(2, 1, 77) * 0.5 + models = [mock_text_encoder] + tokens_list = [tokens] + weights_list = [weights] + + result = strategy.encode_tokens_with_weights(mock_tokenize_strategy, models, tokens_list, weights_list) + + assert len(result) == 1 + assert result[0].shape[0] == 2 + assert result[0].shape[2] == 768 + # Verify weights were applied: 0.5 (encoder output) * 0.5 (weight) = 0.25 + assert torch.allclose(result[0][0, 0, 0], torch.tensor(0.25)) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/library/test_strategy_sd_tokenize.py b/tests/library/test_strategy_sd_tokenize.py new file mode 100644 index 00000000..99b0da07 --- /dev/null +++ b/tests/library/test_strategy_sd_tokenize.py @@ -0,0 +1,378 @@ +import pytest +import torch +from unittest.mock import Mock, patch + +from library.strategy_sd import SdTokenizeStrategy + + +class TestSdTokenizeStrategy: + @pytest.fixture + def mock_tokenizer(self): + """Create a mock CLIP tokenizer.""" + tokenizer = Mock() + tokenizer.model_max_length = 77 + tokenizer.bos_token_id = 49406 + tokenizer.eos_token_id = 49407 + tokenizer.pad_token_id = 49407 + + def tokenize_side_effect(text, **kwargs): + # Simple mock: return incrementing IDs based on text length + # Real tokenizer would split into subwords + num_tokens = min(len(text.split()), 75) + input_ids = torch.arange(1, num_tokens + 1) + + if kwargs.get("return_tensors") == "pt": + max_length = kwargs.get("max_length", 77) + padded = torch.cat( + [ + torch.tensor([tokenizer.bos_token_id]), + input_ids, + torch.tensor([tokenizer.eos_token_id]), + torch.full((max_length - num_tokens - 2,), tokenizer.pad_token_id), + ] + ) + return Mock(input_ids=padded.unsqueeze(0)) + else: + return Mock( + input_ids=torch.cat([torch.tensor([tokenizer.bos_token_id]), input_ids, torch.tensor([tokenizer.eos_token_id])]) + ) + + tokenizer.side_effect = tokenize_side_effect + return tokenizer + + @pytest.fixture + def strategy_v1(self, mock_tokenizer): + """Create a v1 strategy instance with mocked tokenizer.""" + with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer): + strategy = SdTokenizeStrategy(v2=False, max_length=75, tokenizer_cache_dir=None) + return strategy + + @pytest.fixture + def strategy_v2(self, mock_tokenizer): + """Create a v2 strategy instance with mocked tokenizer.""" + mock_tokenizer.pad_token_id = 0 # v2 has different pad token + with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer): + strategy = SdTokenizeStrategy(v2=True, max_length=75, tokenizer_cache_dir=None) + return strategy + + # Test _split_on_break + def test_split_on_break_no_break(self, strategy_v1): + """Test splitting when no BREAK is present.""" + text = "a cat and a dog" + result = strategy_v1._split_on_break(text) + assert len(result) == 1 + assert result[0] == "a cat and a dog" + + def test_split_on_break_single_break(self, strategy_v1): + """Test splitting with single BREAK.""" + text = "a cat BREAK a dog" + result = strategy_v1._split_on_break(text) + assert len(result) == 2 + assert result[0] == "a cat" + assert result[1] == "a dog" + + def test_split_on_break_multiple_breaks(self, strategy_v1): + """Test splitting with multiple BREAKs.""" + text = "a cat BREAK a dog BREAK a bird" + result = strategy_v1._split_on_break(text) + assert len(result) == 3 + assert result[0] == "a cat" + assert result[1] == "a dog" + assert result[2] == "a bird" + + def test_split_on_break_case_sensitive(self, strategy_v1): + """Test that BREAK splitting is case-sensitive.""" + text = "a cat break a dog" # lowercase 'break' should not split + result = strategy_v1._split_on_break(text) + assert len(result) == 1 + assert result[0] == "a cat break a dog" + + text = "a cat Break a dog" # mixed case should not split + result = strategy_v1._split_on_break(text) + assert len(result) == 1 + + def test_split_on_break_with_whitespace(self, strategy_v1): + """Test splitting with extra whitespace around BREAK.""" + text = "a cat BREAK a dog" + result = strategy_v1._split_on_break(text) + assert len(result) == 2 + assert result[0] == "a cat" + assert result[1] == "a dog" + + def test_split_on_break_empty_segments(self, strategy_v1): + """Test splitting filters out empty segments.""" + text = "BREAK a cat BREAK BREAK a dog BREAK" + result = strategy_v1._split_on_break(text) + assert len(result) == 2 + assert result[0] == "a cat" + assert result[1] == "a dog" + + def test_split_on_break_only_break(self, strategy_v1): + """Test splitting with only BREAK returns empty string.""" + text = "BREAK" + result = strategy_v1._split_on_break(text) + assert len(result) == 1 + assert result[0] == "" + + def test_split_on_break_empty_string(self, strategy_v1): + """Test splitting empty string.""" + text = "" + result = strategy_v1._split_on_break(text) + assert len(result) == 1 + assert result[0] == "" + + # Test tokenize without BREAK + def test_tokenize_single_text_no_break(self, strategy_v1): + """Test tokenizing single text without BREAK.""" + text = "a cat" + result = strategy_v1.tokenize(text) + assert len(result) == 1 + assert isinstance(result[0], torch.Tensor) + assert result[0].dim() == 3 # [batch, n_chunks, seq_len] + + def test_tokenize_list_no_break(self, strategy_v1): + """Test tokenizing list of texts without BREAK.""" + texts = ["a cat", "a dog"] + result = strategy_v1.tokenize(texts) + assert len(result) == 1 + assert result[0].shape[0] == 2 # batch size + + # Test tokenize with BREAK + def test_tokenize_single_break(self, strategy_v1): + """Test tokenizing text with single BREAK.""" + text = "a cat BREAK a dog" + result = strategy_v1.tokenize(text) + assert len(result) == 1 + assert isinstance(result[0], torch.Tensor) + # Should have concatenated tokens from both segments + + def test_tokenize_multiple_breaks(self, strategy_v1): + """Test tokenizing text with multiple BREAKs.""" + text = "a cat BREAK a dog BREAK a bird" + result = strategy_v1.tokenize(text) + assert len(result) == 1 + assert isinstance(result[0], torch.Tensor) + + def test_tokenize_list_with_breaks(self, strategy_v1): + """Test tokenizing list where some texts have BREAKs.""" + texts = ["a cat BREAK a dog", "a bird"] + result = strategy_v1.tokenize(texts) + assert len(result) == 1 + assert result[0].shape[0] == 2 # batch size + + # Test tokenize_with_weights without BREAK + def test_tokenize_with_weights_no_break(self, strategy_v1): + """Test weighted tokenization without BREAK.""" + text = "a cat" + tokens_list, weights_list = strategy_v1.tokenize_with_weights(text) + assert len(tokens_list) == 1 + assert len(weights_list) == 1 + assert isinstance(tokens_list[0], torch.Tensor) + assert isinstance(weights_list[0], torch.Tensor) + assert tokens_list[0].shape == weights_list[0].shape + + def test_tokenize_with_weights_list_no_break(self, strategy_v1): + """Test weighted tokenization of list without BREAK.""" + texts = ["a cat", "a dog"] + tokens_list, weights_list = strategy_v1.tokenize_with_weights(texts) + assert len(tokens_list) == 1 + assert len(weights_list) == 1 + assert tokens_list[0].shape[0] == 2 # batch size + assert tokens_list[0].shape == weights_list[0].shape + + # Test tokenize_with_weights with BREAK + def test_tokenize_with_weights_single_break(self, strategy_v1): + """Test weighted tokenization with single BREAK.""" + text = "a cat BREAK a dog" + tokens_list, weights_list = strategy_v1.tokenize_with_weights(text) + assert len(tokens_list) == 1 + assert len(weights_list) == 1 + assert isinstance(tokens_list[0], torch.Tensor) + assert isinstance(weights_list[0], torch.Tensor) + assert tokens_list[0].shape == weights_list[0].shape + + def test_tokenize_with_weights_multiple_breaks(self, strategy_v1): + """Test weighted tokenization with multiple BREAKs.""" + text = "a cat BREAK a dog BREAK a bird" + tokens_list, weights_list = strategy_v1.tokenize_with_weights(text) + assert len(tokens_list) == 1 + assert len(weights_list) == 1 + assert tokens_list[0].shape == weights_list[0].shape + + def test_tokenize_with_weights_list_with_breaks(self, strategy_v1): + """Test weighted tokenization of list with BREAKs.""" + texts = ["a cat BREAK a dog", "a bird BREAK a fish"] + tokens_list, weights_list = strategy_v1.tokenize_with_weights(texts) + assert len(tokens_list) == 1 + assert len(weights_list) == 1 + assert tokens_list[0].shape[0] == 2 # batch size + assert tokens_list[0].shape == weights_list[0].shape + + # Test weighted prompts (with attention syntax) + def test_tokenize_with_weights_attention_syntax(self, strategy_v1): + """Test weighted tokenization with attention syntax like (word:1.5).""" + text = "a (cat:1.5) and a dog" + tokens_list, weights_list = strategy_v1.tokenize_with_weights(text) + assert len(tokens_list) == 1 + assert len(weights_list) == 1 + # Weights should differ from 1.0 for the emphasized word + + def test_tokenize_with_weights_attention_and_break(self, strategy_v1): + """Test weighted tokenization with both attention syntax and BREAK.""" + text = "a (cat:1.5) BREAK a [dog:0.8]" + tokens_list, weights_list = strategy_v1.tokenize_with_weights(text) + assert len(tokens_list) == 1 + assert len(weights_list) == 1 + assert tokens_list[0].shape == weights_list[0].shape + + def test_break_splits_long_prompts_into_chunks(self, strategy_v1): + """Test that BREAK causes long prompts to split into expected number of chunks.""" + # Create a prompt with 80 tokens before BREAK and 80 after + # Each "word" typically becomes 1-2 tokens, so ~40-80 words for 80 tokens + long_segment = " ".join([f"word{i}" for i in range(40)]) # ~80 tokens + text = f"{long_segment} BREAK {long_segment}" + + tokens_list, weights_list = strategy_v1.tokenize_with_weights(text) + + # With model_max_length=77, we expect: + # - First segment: 80 tokens -> needs 2 chunks (77 + remainder) + # - Second segment: 80 tokens -> needs 2 chunks (77 + remainder) + # Total: 4 chunks (2 per segment) + + assert len(tokens_list) == 1 + assert len(weights_list) == 1 + + # Check that we got multiple chunks by looking at the shape + # The concatenated result should be longer than a single chunk (77 tokens) + tokens = tokens_list[0] + weights = weights_list[0] + + # Should have significantly more than 77 tokens due to concatenation + assert tokens.shape[-1] > 77, f"Expected >77 tokens but got {tokens.shape[-1]}" + + # With 2 segments of ~80 tokens each, we expect ~160 total tokens after concatenation + # (exact number depends on tokenizer behavior, but should be in this range) + assert tokens.shape[-1] >= 150, f"Expected >=150 tokens for 2 long segments but got {tokens.shape[-1]}" + + def test_break_splits_result_in_proper_chunks(self, strategy_v1): + """Test that BREAK splitting results in proper chunk structure.""" + # Segment 1: ~40 tokens, Segment 2: ~40 tokens + segment1 = " ".join([f"word{i}" for i in range(20)]) + segment2 = " ".join([f"word{i}" for i in range(20, 40)]) + text = f"{segment1} BREAK {segment2}" + + tokens_list, weights_list = strategy_v1.tokenize_with_weights(text) + + tokens = tokens_list[0] + weights = weights_list[0] + + # Should be concatenated from 2 segments + # Each segment fits in one chunk (< 77 tokens), so total should be ~80 tokens + assert tokens.shape == weights.shape + assert tokens.shape[-1] > 40, "Should have tokens from both segments" + + # Test v1 vs v2 + def test_v1_vs_v2_initialization(self, mock_tokenizer): + """Test that v1 and v2 are initialized differently.""" + with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer): + strategy_v1 = SdTokenizeStrategy(v2=False, max_length=75) + strategy_v2 = SdTokenizeStrategy(v2=True, max_length=75) + + assert strategy_v1.tokenizer is not None + assert strategy_v2.tokenizer is not None + assert strategy_v1.max_length == 77 # 75 + 2 for BOS/EOS + assert strategy_v2.max_length == 77 + + # Test max_length handling + def test_max_length_none(self, mock_tokenizer): + """Test that None max_length uses tokenizer's model_max_length.""" + with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer): + strategy = SdTokenizeStrategy(v2=False, max_length=None) + assert strategy.max_length == mock_tokenizer.model_max_length + + def test_max_length_custom(self, mock_tokenizer): + """Test custom max_length.""" + with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer): + strategy = SdTokenizeStrategy(v2=False, max_length=150) + assert strategy.max_length == 152 # 150 + 2 for BOS/EOS + + +class TestEdgeCases: + """Test edge cases for tokenization.""" + + @pytest.fixture + def mock_tokenizer(self): + """Create a mock CLIP tokenizer.""" + tokenizer = Mock() + tokenizer.model_max_length = 77 + tokenizer.bos_token_id = 49406 + tokenizer.eos_token_id = 49407 + tokenizer.pad_token_id = 49407 + + def tokenize_side_effect(text, **kwargs): + num_tokens = min(len(text.split()), 75) + input_ids = torch.arange(1, num_tokens + 1) + + if kwargs.get("return_tensors") == "pt": + max_length = kwargs.get("max_length", 77) + padded = torch.cat( + [ + torch.tensor([tokenizer.bos_token_id]), + input_ids, + torch.tensor([tokenizer.eos_token_id]), + torch.full((max_length - num_tokens - 2,), tokenizer.pad_token_id), + ] + ) + return Mock(input_ids=padded.unsqueeze(0)) + else: + return Mock( + input_ids=torch.cat([torch.tensor([tokenizer.bos_token_id]), input_ids, torch.tensor([tokenizer.eos_token_id])]) + ) + + tokenizer.side_effect = tokenize_side_effect + return tokenizer + + def test_very_long_text_with_breaks(self, mock_tokenizer): + """Test very long text with multiple BREAKs.""" + with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer): + strategy = SdTokenizeStrategy(v2=False, max_length=75) + # Create long text segments + long_text = " ".join([f"word{i}" for i in range(50)]) + text = f"{long_text} BREAK {long_text} BREAK {long_text}" + + result = strategy.tokenize(text) + assert len(result) == 1 + assert isinstance(result[0], torch.Tensor) + + def test_break_at_boundaries(self, mock_tokenizer): + """Test BREAK at start and end of text.""" + with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer): + strategy = SdTokenizeStrategy(v2=False, max_length=75) + + # BREAK at start + text = "BREAK a cat" + result = strategy.tokenize(text) + assert len(result) == 1 + + # BREAK at end + text = "a cat BREAK" + result = strategy.tokenize(text) + assert len(result) == 1 + + # BREAK at both ends + text = "BREAK a cat BREAK" + result = strategy.tokenize(text) + assert len(result) == 1 + + def test_consecutive_breaks(self, mock_tokenizer): + """Test multiple consecutive BREAKs.""" + with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer): + strategy = SdTokenizeStrategy(v2=False, max_length=75) + text = "a cat BREAK BREAK BREAK a dog" + result = strategy.tokenize(text) + assert len(result) == 1 + # Should only create 2 segments (consecutive BREAKs create empty segments that are filtered) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])