Compare commits

...

3 Commits

Author SHA1 Message Date
Dave Lage
31fdaeb215 Merge b4b35c34bd into fa53f71ec0 2026-04-04 06:04:10 +09:00
Kohya S.
fa53f71ec0 fix: improve numerical stability by conditionally using float32 in Anima (#2302)
* fix: improve numerical stability by conditionally using float32 in block computations

* doc: update README for improvement stability for fp16 training on Anima in version 0.10.3
2026-04-02 12:36:29 +09:00
rockerBOO
b4b35c34bd Add BREAK for captions in strategy_sd 2025-10-10 15:07:31 -04:00
6 changed files with 674 additions and 60 deletions

View File

@@ -50,6 +50,9 @@ Stable Diffusion等の画像生成モデルの学習、モデルによる画像
### 更新履歴
- **Version 0.10.3 (2026-04-02):**
- Animaでfp16で学習する際の安定性をさらに改善しました。[PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) 問題をご報告いただいた方々に深く感謝します。
- **Version 0.10.2 (2026-03-30):**
- SD/SDXLのLECO学習に対応しました。[PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) および [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294) umisetokikaze氏に深く感謝します。
- 詳細は[ドキュメント](./docs/train_leco.md)をご覧ください。

View File

@@ -47,6 +47,9 @@ If you find this project helpful, please consider supporting its development via
### Change History
- **Version 0.10.3 (2026-04-02):**
- Stability when training with fp16 on Anima has been further improved. See [PR #2302](https://github.com/kohya-ss/sd-scripts/pull/2302) for details. We deeply appreciate those who reported the issue.
- **Version 0.10.2 (2026-03-30):**
- LECO training for SD/SDXL is now supported. Many thanks to umisetokikaze for [PR #2285](https://github.com/kohya-ss/sd-scripts/pull/2285) and [PR #2294](https://github.com/kohya-ss/sd-scripts/pull/2294).
- Please refer to the [documentation](./docs/train_leco.md) for details.

View File

@@ -738,9 +738,9 @@ class FinalLayer(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
use_fp32: bool = False,
):
# Compute AdaLN modulation parameters (in float32 when fp16 to avoid overflow in Linear layers)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
with torch.autocast(device_type=x_B_T_H_W_D.device.type, dtype=torch.float32, enabled=use_fp32):
if self.use_adaln_lora:
assert adaln_lora_B_T_3D is not None
@@ -863,11 +863,11 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
if use_fp32:
# Cast to float32 for better numerical stability in residual connections. Each module will cast back to float16 by enclosing autocast context.
x_B_T_H_W_D = x_B_T_H_W_D.float()
@@ -959,6 +959,7 @@ class Block(nn.Module):
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
use_fp32: bool = False,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -972,6 +973,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -994,6 +996,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1007,6 +1010,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1018,6 +1022,7 @@ class Block(nn.Module):
emb_B_T_D,
crossattn_emb,
attn_params,
use_fp32,
rope_emb_L_1_1_D,
adaln_lora_B_T_3D,
extra_per_block_pos_emb,
@@ -1338,16 +1343,19 @@ class Anima(nn.Module):
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
# Determine whether to use float32 for block computations based on input dtype (use float32 for better stability when input is float16)
use_fp32 = x_B_T_H_W_D.dtype == torch.float16
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, **block_kwargs)
x_B_T_H_W_D = block(x_B_T_H_W_D, t_embedding_B_T_D, crossattn_emb, attn_params, use_fp32, **block_kwargs)
if self.blocks_to_swap:
self.offloader.submit_move_blocks(self.blocks, block_idx)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D, use_fp32=use_fp32)
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)
return x_B_C_Tt_Hp_Wp

View File

@@ -31,81 +31,171 @@ class SdTokenizeStrategy(TokenizeStrategy):
)
else:
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
if max_length is None:
self.max_length = self.tokenizer.model_max_length
else:
self.max_length = max_length + 2
self.break_separator = "BREAK"
def _split_on_break(self, text: str) -> List[str]:
"""Split text on BREAK separator (case-sensitive), filtering empty segments."""
segments = text.split(self.break_separator)
# Filter out empty or whitespace-only segments
filtered = [seg.strip() for seg in segments if seg.strip()]
# Return at least one segment to maintain consistency
return filtered if filtered else [""]
def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Tokenize multiple segments and concatenate them."""
if len(segments) == 1:
# No BREAK present, use existing logic
if weighted:
return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True)
else:
tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length)
return tokens, None
# Multiple segments - tokenize each separately
all_tokens = []
all_weights = [] if weighted else None
for segment in segments:
if weighted:
seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True)
all_tokens.append(seg_tokens)
all_weights.append(seg_weights)
else:
seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length)
all_tokens.append(seg_tokens)
# Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len])
combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0)
combined_weights = None
if weighted:
combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0)
return combined_tokens, combined_weights
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
tokens_list = []
for t in text:
segments = self._split_on_break(t)
tokens, _ = self._tokenize_segments(segments, weighted=False)
tokens_list.append(tokens)
# Pad tokens to same length for stacking
max_length = max(t.shape[-1] for t in tokens_list)
padded_tokens = []
for tokens in tokens_list:
if tokens.shape[-1] < max_length:
# Pad with pad_token_id
pad_size = max_length - tokens.shape[-1]
if tokens.dim() == 2:
padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype)
tokens = torch.cat([tokens, padding], dim=1)
else:
padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype)
tokens = torch.cat([tokens, padding], dim=0)
padded_tokens.append(tokens)
return [torch.stack(padded_tokens, dim=0)]
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []
for t in text:
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
segments = self._split_on_break(t)
tokens, weights = self._tokenize_segments(segments, weighted=True)
tokens_list.append(tokens)
weights_list.append(weights)
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
class SdTextEncodingStrategy(TextEncodingStrategy):
def __init__(self, clip_skip: Optional[int] = None) -> None:
self.clip_skip = clip_skip
def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor:
"""Encode tokens with optional CLIP skip."""
if self.clip_skip is None:
return text_encoder(tokens)[0]
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
hidden_states = enc_out["hidden_states"][-self.clip_skip]
return text_encoder.text_model.final_layer_norm(hidden_states)
def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor,
max_token_length: int, model_max_length: int,
tokenizer: Any) -> torch.Tensor:
"""Reconstruct embeddings from chunked encoding."""
v1 = tokenizer.pad_token_id == tokenizer.eos_token_id
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
if not v1:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2]
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == tokenizer.eos_token:
chunk[j, 0] = chunk[j, 1]
states_list.append(chunk)
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
return torch.cat(states_list, dim=1)
def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor,
weights: torch.Tensor) -> torch.Tensor:
"""Apply weights for single chunk case (no max_token_length)."""
return encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor,
weights: torch.Tensor) -> torch.Tensor:
"""Apply weights for multi-chunk case (with max_token_length)."""
for i in range(weights.shape[1]):
start_idx = i * 75 + 1
end_idx = i * 75 + 76
encoder_hidden_states[:, start_idx:end_idx] = (
encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1)
)
return encoder_hidden_states
def encode_tokens(
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
) -> List[torch.Tensor]:
text_encoder = models[0]
tokens = tokens[0]
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
# tokens: b,n,77
b_size = tokens.size()[0]
max_token_length = tokens.size()[1] * tokens.size()[2]
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
tokens = tokens.reshape((-1, model_max_length))
tokens = tokens.to(text_encoder.device)
if self.clip_skip is None:
encoder_hidden_states = text_encoder(tokens)[0]
else:
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
# bs*3, 77, 768 or 1024
encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens)
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
if max_token_length != model_max_length:
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
if not v1:
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # <BOS> の後から 最後の前まで
if i > 0:
for j in range(len(chunk)):
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
# 空、つまり <BOS> <EOS> <PAD> ...のパターン
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
encoder_hidden_states = torch.cat(states_list, dim=1)
else:
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
for i in range(1, max_token_length, model_max_length):
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # <BOS> の後から <EOS> の前まで
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
encoder_hidden_states = torch.cat(states_list, dim=1)
encoder_hidden_states = self._reconstruct_embeddings(
encoder_hidden_states, tokens, max_token_length,
model_max_length, sd_tokenize_strategy.tokenizer
)
return [encoder_hidden_states]
def encode_tokens_with_weights(
self,
tokenize_strategy: TokenizeStrategy,
@@ -114,23 +204,15 @@ class SdTextEncodingStrategy(TextEncodingStrategy):
weights_list: List[torch.Tensor],
) -> List[torch.Tensor]:
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
weights = weights_list[0].to(encoder_hidden_states.device)
# apply weights
if weights.shape[1] == 1: # no max_token_length
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
if weights.shape[1] == 1:
encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights)
else:
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
for i in range(weights.shape[1]):
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
:, i, 1:-1
].unsqueeze(-1)
encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights)
return [encoder_hidden_states]
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
# and we keep the old npz for the backward compatibility.

View File

@@ -0,0 +1,140 @@
import pytest
import torch
from unittest.mock import Mock
from library.strategy_sd import SdTextEncodingStrategy
class TestSdTextEncodingStrategy:
@pytest.fixture
def strategy(self):
"""Create strategy instance with default settings."""
return SdTextEncodingStrategy(clip_skip=None)
@pytest.fixture
def strategy_with_clip_skip(self):
"""Create strategy instance with CLIP skip enabled."""
return SdTextEncodingStrategy(clip_skip=2)
@pytest.fixture
def mock_tokenizer(self):
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.model_max_length = 77
tokenizer.pad_token_id = 0
tokenizer.eos_token = 2
tokenizer.eos_token_id = 2
return tokenizer
@pytest.fixture
def mock_text_encoder(self):
"""Create a mock text encoder."""
encoder = Mock()
encoder.device = torch.device("cpu")
def encode_side_effect(tokens, output_hidden_states=False, return_dict=False):
batch_size = tokens.shape[0]
seq_len = tokens.shape[1]
hidden_size = 768
# Create deterministic hidden states
hidden_state = torch.ones(batch_size, seq_len, hidden_size) * 0.5
if return_dict:
result = {
"hidden_states": [
hidden_state * 0.8,
hidden_state * 0.9,
hidden_state * 1.0,
]
}
return result
else:
return [hidden_state]
encoder.side_effect = encode_side_effect
encoder.text_model = Mock()
encoder.text_model.final_layer_norm = lambda x: x
return encoder
@pytest.fixture
def mock_tokenize_strategy(self, mock_tokenizer):
"""Create a mock tokenize strategy."""
strategy = Mock()
strategy.tokenizer = mock_tokenizer
return strategy
# Test _encode_with_clip_skip
def test_encode_without_clip_skip(self, strategy, mock_text_encoder):
"""Test encoding without CLIP skip."""
tokens = torch.arange(154).reshape(2, 77)
result = strategy._encode_with_clip_skip(mock_text_encoder, tokens)
assert result.shape == (2, 77, 768)
# Verify deterministic output
assert torch.allclose(result[0, 0, 0], torch.tensor(0.5))
def test_encode_with_clip_skip(self, strategy_with_clip_skip, mock_text_encoder):
"""Test encoding with CLIP skip."""
tokens = torch.arange(154).reshape(2, 77)
result = strategy_with_clip_skip._encode_with_clip_skip(mock_text_encoder, tokens)
assert result.shape == (2, 77, 768)
# With clip_skip=2, should use second-to-last hidden state (0.5 * 0.9 = 0.45)
assert torch.allclose(result[0, 0, 0], torch.tensor(0.45))
# Test _apply_weights_single_chunk
def test_apply_weights_single_chunk(self, strategy):
"""Test applying weights for single chunk case."""
encoder_hidden_states = torch.ones(2, 77, 768)
weights = torch.ones(2, 1, 77) * 0.5
result = strategy._apply_weights_single_chunk(encoder_hidden_states, weights)
assert result.shape == (2, 77, 768)
# Verify weights were applied: 1.0 * 0.5 = 0.5
assert torch.allclose(result[0, 0, 0], torch.tensor(0.5))
# Test _apply_weights_multi_chunk
def test_apply_weights_multi_chunk(self, strategy):
"""Test applying weights for multi-chunk case."""
# Simulating 2 chunks: 2*75+2 = 152 tokens
encoder_hidden_states = torch.ones(2, 152, 768)
weights = torch.ones(2, 2, 77) * 0.5
result = strategy._apply_weights_multi_chunk(encoder_hidden_states, weights)
assert result.shape == (2, 152, 768)
# Check that weights were applied to middle sections
assert torch.allclose(result[0, 1, 0], torch.tensor(0.5))
assert torch.allclose(result[0, 76, 0], torch.tensor(0.5))
# Integration tests
def test_encode_tokens_basic(self, strategy, mock_tokenize_strategy, mock_text_encoder):
"""Test basic token encoding flow."""
tokens = torch.arange(154).reshape(2, 1, 77)
models = [mock_text_encoder]
tokens_list = [tokens]
result = strategy.encode_tokens(mock_tokenize_strategy, models, tokens_list)
assert len(result) == 1
assert result[0].shape[0] == 2 # batch size
assert result[0].shape[2] == 768 # hidden size
# Verify deterministic output
assert torch.allclose(result[0][0, 0, 0], torch.tensor(0.5))
def test_encode_tokens_with_weights_single_chunk(self, strategy, mock_tokenize_strategy, mock_text_encoder):
"""Test weighted encoding with single chunk."""
tokens = torch.arange(154).reshape(2, 1, 77)
weights = torch.ones(2, 1, 77) * 0.5
models = [mock_text_encoder]
tokens_list = [tokens]
weights_list = [weights]
result = strategy.encode_tokens_with_weights(mock_tokenize_strategy, models, tokens_list, weights_list)
assert len(result) == 1
assert result[0].shape[0] == 2
assert result[0].shape[2] == 768
# Verify weights were applied: 0.5 (encoder output) * 0.5 (weight) = 0.25
assert torch.allclose(result[0][0, 0, 0], torch.tensor(0.25))
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,378 @@
import pytest
import torch
from unittest.mock import Mock, patch
from library.strategy_sd import SdTokenizeStrategy
class TestSdTokenizeStrategy:
@pytest.fixture
def mock_tokenizer(self):
"""Create a mock CLIP tokenizer."""
tokenizer = Mock()
tokenizer.model_max_length = 77
tokenizer.bos_token_id = 49406
tokenizer.eos_token_id = 49407
tokenizer.pad_token_id = 49407
def tokenize_side_effect(text, **kwargs):
# Simple mock: return incrementing IDs based on text length
# Real tokenizer would split into subwords
num_tokens = min(len(text.split()), 75)
input_ids = torch.arange(1, num_tokens + 1)
if kwargs.get("return_tensors") == "pt":
max_length = kwargs.get("max_length", 77)
padded = torch.cat(
[
torch.tensor([tokenizer.bos_token_id]),
input_ids,
torch.tensor([tokenizer.eos_token_id]),
torch.full((max_length - num_tokens - 2,), tokenizer.pad_token_id),
]
)
return Mock(input_ids=padded.unsqueeze(0))
else:
return Mock(
input_ids=torch.cat([torch.tensor([tokenizer.bos_token_id]), input_ids, torch.tensor([tokenizer.eos_token_id])])
)
tokenizer.side_effect = tokenize_side_effect
return tokenizer
@pytest.fixture
def strategy_v1(self, mock_tokenizer):
"""Create a v1 strategy instance with mocked tokenizer."""
with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
strategy = SdTokenizeStrategy(v2=False, max_length=75, tokenizer_cache_dir=None)
return strategy
@pytest.fixture
def strategy_v2(self, mock_tokenizer):
"""Create a v2 strategy instance with mocked tokenizer."""
mock_tokenizer.pad_token_id = 0 # v2 has different pad token
with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
strategy = SdTokenizeStrategy(v2=True, max_length=75, tokenizer_cache_dir=None)
return strategy
# Test _split_on_break
def test_split_on_break_no_break(self, strategy_v1):
"""Test splitting when no BREAK is present."""
text = "a cat and a dog"
result = strategy_v1._split_on_break(text)
assert len(result) == 1
assert result[0] == "a cat and a dog"
def test_split_on_break_single_break(self, strategy_v1):
"""Test splitting with single BREAK."""
text = "a cat BREAK a dog"
result = strategy_v1._split_on_break(text)
assert len(result) == 2
assert result[0] == "a cat"
assert result[1] == "a dog"
def test_split_on_break_multiple_breaks(self, strategy_v1):
"""Test splitting with multiple BREAKs."""
text = "a cat BREAK a dog BREAK a bird"
result = strategy_v1._split_on_break(text)
assert len(result) == 3
assert result[0] == "a cat"
assert result[1] == "a dog"
assert result[2] == "a bird"
def test_split_on_break_case_sensitive(self, strategy_v1):
"""Test that BREAK splitting is case-sensitive."""
text = "a cat break a dog" # lowercase 'break' should not split
result = strategy_v1._split_on_break(text)
assert len(result) == 1
assert result[0] == "a cat break a dog"
text = "a cat Break a dog" # mixed case should not split
result = strategy_v1._split_on_break(text)
assert len(result) == 1
def test_split_on_break_with_whitespace(self, strategy_v1):
"""Test splitting with extra whitespace around BREAK."""
text = "a cat BREAK a dog"
result = strategy_v1._split_on_break(text)
assert len(result) == 2
assert result[0] == "a cat"
assert result[1] == "a dog"
def test_split_on_break_empty_segments(self, strategy_v1):
"""Test splitting filters out empty segments."""
text = "BREAK a cat BREAK BREAK a dog BREAK"
result = strategy_v1._split_on_break(text)
assert len(result) == 2
assert result[0] == "a cat"
assert result[1] == "a dog"
def test_split_on_break_only_break(self, strategy_v1):
"""Test splitting with only BREAK returns empty string."""
text = "BREAK"
result = strategy_v1._split_on_break(text)
assert len(result) == 1
assert result[0] == ""
def test_split_on_break_empty_string(self, strategy_v1):
"""Test splitting empty string."""
text = ""
result = strategy_v1._split_on_break(text)
assert len(result) == 1
assert result[0] == ""
# Test tokenize without BREAK
def test_tokenize_single_text_no_break(self, strategy_v1):
"""Test tokenizing single text without BREAK."""
text = "a cat"
result = strategy_v1.tokenize(text)
assert len(result) == 1
assert isinstance(result[0], torch.Tensor)
assert result[0].dim() == 3 # [batch, n_chunks, seq_len]
def test_tokenize_list_no_break(self, strategy_v1):
"""Test tokenizing list of texts without BREAK."""
texts = ["a cat", "a dog"]
result = strategy_v1.tokenize(texts)
assert len(result) == 1
assert result[0].shape[0] == 2 # batch size
# Test tokenize with BREAK
def test_tokenize_single_break(self, strategy_v1):
"""Test tokenizing text with single BREAK."""
text = "a cat BREAK a dog"
result = strategy_v1.tokenize(text)
assert len(result) == 1
assert isinstance(result[0], torch.Tensor)
# Should have concatenated tokens from both segments
def test_tokenize_multiple_breaks(self, strategy_v1):
"""Test tokenizing text with multiple BREAKs."""
text = "a cat BREAK a dog BREAK a bird"
result = strategy_v1.tokenize(text)
assert len(result) == 1
assert isinstance(result[0], torch.Tensor)
def test_tokenize_list_with_breaks(self, strategy_v1):
"""Test tokenizing list where some texts have BREAKs."""
texts = ["a cat BREAK a dog", "a bird"]
result = strategy_v1.tokenize(texts)
assert len(result) == 1
assert result[0].shape[0] == 2 # batch size
# Test tokenize_with_weights without BREAK
def test_tokenize_with_weights_no_break(self, strategy_v1):
"""Test weighted tokenization without BREAK."""
text = "a cat"
tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
assert len(tokens_list) == 1
assert len(weights_list) == 1
assert isinstance(tokens_list[0], torch.Tensor)
assert isinstance(weights_list[0], torch.Tensor)
assert tokens_list[0].shape == weights_list[0].shape
def test_tokenize_with_weights_list_no_break(self, strategy_v1):
"""Test weighted tokenization of list without BREAK."""
texts = ["a cat", "a dog"]
tokens_list, weights_list = strategy_v1.tokenize_with_weights(texts)
assert len(tokens_list) == 1
assert len(weights_list) == 1
assert tokens_list[0].shape[0] == 2 # batch size
assert tokens_list[0].shape == weights_list[0].shape
# Test tokenize_with_weights with BREAK
def test_tokenize_with_weights_single_break(self, strategy_v1):
"""Test weighted tokenization with single BREAK."""
text = "a cat BREAK a dog"
tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
assert len(tokens_list) == 1
assert len(weights_list) == 1
assert isinstance(tokens_list[0], torch.Tensor)
assert isinstance(weights_list[0], torch.Tensor)
assert tokens_list[0].shape == weights_list[0].shape
def test_tokenize_with_weights_multiple_breaks(self, strategy_v1):
"""Test weighted tokenization with multiple BREAKs."""
text = "a cat BREAK a dog BREAK a bird"
tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
assert len(tokens_list) == 1
assert len(weights_list) == 1
assert tokens_list[0].shape == weights_list[0].shape
def test_tokenize_with_weights_list_with_breaks(self, strategy_v1):
"""Test weighted tokenization of list with BREAKs."""
texts = ["a cat BREAK a dog", "a bird BREAK a fish"]
tokens_list, weights_list = strategy_v1.tokenize_with_weights(texts)
assert len(tokens_list) == 1
assert len(weights_list) == 1
assert tokens_list[0].shape[0] == 2 # batch size
assert tokens_list[0].shape == weights_list[0].shape
# Test weighted prompts (with attention syntax)
def test_tokenize_with_weights_attention_syntax(self, strategy_v1):
"""Test weighted tokenization with attention syntax like (word:1.5)."""
text = "a (cat:1.5) and a dog"
tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
assert len(tokens_list) == 1
assert len(weights_list) == 1
# Weights should differ from 1.0 for the emphasized word
def test_tokenize_with_weights_attention_and_break(self, strategy_v1):
"""Test weighted tokenization with both attention syntax and BREAK."""
text = "a (cat:1.5) BREAK a [dog:0.8]"
tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
assert len(tokens_list) == 1
assert len(weights_list) == 1
assert tokens_list[0].shape == weights_list[0].shape
def test_break_splits_long_prompts_into_chunks(self, strategy_v1):
"""Test that BREAK causes long prompts to split into expected number of chunks."""
# Create a prompt with 80 tokens before BREAK and 80 after
# Each "word" typically becomes 1-2 tokens, so ~40-80 words for 80 tokens
long_segment = " ".join([f"word{i}" for i in range(40)]) # ~80 tokens
text = f"{long_segment} BREAK {long_segment}"
tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
# With model_max_length=77, we expect:
# - First segment: 80 tokens -> needs 2 chunks (77 + remainder)
# - Second segment: 80 tokens -> needs 2 chunks (77 + remainder)
# Total: 4 chunks (2 per segment)
assert len(tokens_list) == 1
assert len(weights_list) == 1
# Check that we got multiple chunks by looking at the shape
# The concatenated result should be longer than a single chunk (77 tokens)
tokens = tokens_list[0]
weights = weights_list[0]
# Should have significantly more than 77 tokens due to concatenation
assert tokens.shape[-1] > 77, f"Expected >77 tokens but got {tokens.shape[-1]}"
# With 2 segments of ~80 tokens each, we expect ~160 total tokens after concatenation
# (exact number depends on tokenizer behavior, but should be in this range)
assert tokens.shape[-1] >= 150, f"Expected >=150 tokens for 2 long segments but got {tokens.shape[-1]}"
def test_break_splits_result_in_proper_chunks(self, strategy_v1):
"""Test that BREAK splitting results in proper chunk structure."""
# Segment 1: ~40 tokens, Segment 2: ~40 tokens
segment1 = " ".join([f"word{i}" for i in range(20)])
segment2 = " ".join([f"word{i}" for i in range(20, 40)])
text = f"{segment1} BREAK {segment2}"
tokens_list, weights_list = strategy_v1.tokenize_with_weights(text)
tokens = tokens_list[0]
weights = weights_list[0]
# Should be concatenated from 2 segments
# Each segment fits in one chunk (< 77 tokens), so total should be ~80 tokens
assert tokens.shape == weights.shape
assert tokens.shape[-1] > 40, "Should have tokens from both segments"
# Test v1 vs v2
def test_v1_vs_v2_initialization(self, mock_tokenizer):
"""Test that v1 and v2 are initialized differently."""
with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
strategy_v1 = SdTokenizeStrategy(v2=False, max_length=75)
strategy_v2 = SdTokenizeStrategy(v2=True, max_length=75)
assert strategy_v1.tokenizer is not None
assert strategy_v2.tokenizer is not None
assert strategy_v1.max_length == 77 # 75 + 2 for BOS/EOS
assert strategy_v2.max_length == 77
# Test max_length handling
def test_max_length_none(self, mock_tokenizer):
"""Test that None max_length uses tokenizer's model_max_length."""
with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
strategy = SdTokenizeStrategy(v2=False, max_length=None)
assert strategy.max_length == mock_tokenizer.model_max_length
def test_max_length_custom(self, mock_tokenizer):
"""Test custom max_length."""
with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
strategy = SdTokenizeStrategy(v2=False, max_length=150)
assert strategy.max_length == 152 # 150 + 2 for BOS/EOS
class TestEdgeCases:
"""Test edge cases for tokenization."""
@pytest.fixture
def mock_tokenizer(self):
"""Create a mock CLIP tokenizer."""
tokenizer = Mock()
tokenizer.model_max_length = 77
tokenizer.bos_token_id = 49406
tokenizer.eos_token_id = 49407
tokenizer.pad_token_id = 49407
def tokenize_side_effect(text, **kwargs):
num_tokens = min(len(text.split()), 75)
input_ids = torch.arange(1, num_tokens + 1)
if kwargs.get("return_tensors") == "pt":
max_length = kwargs.get("max_length", 77)
padded = torch.cat(
[
torch.tensor([tokenizer.bos_token_id]),
input_ids,
torch.tensor([tokenizer.eos_token_id]),
torch.full((max_length - num_tokens - 2,), tokenizer.pad_token_id),
]
)
return Mock(input_ids=padded.unsqueeze(0))
else:
return Mock(
input_ids=torch.cat([torch.tensor([tokenizer.bos_token_id]), input_ids, torch.tensor([tokenizer.eos_token_id])])
)
tokenizer.side_effect = tokenize_side_effect
return tokenizer
def test_very_long_text_with_breaks(self, mock_tokenizer):
"""Test very long text with multiple BREAKs."""
with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
strategy = SdTokenizeStrategy(v2=False, max_length=75)
# Create long text segments
long_text = " ".join([f"word{i}" for i in range(50)])
text = f"{long_text} BREAK {long_text} BREAK {long_text}"
result = strategy.tokenize(text)
assert len(result) == 1
assert isinstance(result[0], torch.Tensor)
def test_break_at_boundaries(self, mock_tokenizer):
"""Test BREAK at start and end of text."""
with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
strategy = SdTokenizeStrategy(v2=False, max_length=75)
# BREAK at start
text = "BREAK a cat"
result = strategy.tokenize(text)
assert len(result) == 1
# BREAK at end
text = "a cat BREAK"
result = strategy.tokenize(text)
assert len(result) == 1
# BREAK at both ends
text = "BREAK a cat BREAK"
result = strategy.tokenize(text)
assert len(result) == 1
def test_consecutive_breaks(self, mock_tokenizer):
"""Test multiple consecutive BREAKs."""
with patch.object(SdTokenizeStrategy, "_load_tokenizer", return_value=mock_tokenizer):
strategy = SdTokenizeStrategy(v2=False, max_length=75)
text = "a cat BREAK BREAK BREAK a dog"
result = strategy.tokenize(text)
assert len(result) == 1
# Should only create 2 segments (consecutive BREAKs create empty segments that are filtered)
if __name__ == "__main__":
pytest.main([__file__, "-v"])