Files
Kohya-ss-sd-scripts/tests/library/test_strategy_lumina.py
2025-06-29 22:21:48 +09:00

242 lines
9.4 KiB
Python

import os
import tempfile
import torch
import numpy as np
from unittest.mock import patch
from transformers import Gemma2Model
from library.strategy_lumina import (
LuminaTokenizeStrategy,
LuminaTextEncodingStrategy,
LuminaTextEncoderOutputsCachingStrategy,
LuminaLatentsCachingStrategy,
)
class SimpleMockGemma2Model:
"""Lightweight mock that avoids initializing the actual Gemma2Model"""
def __init__(self, hidden_size=2304):
self.device = torch.device("cpu")
self._hidden_size = hidden_size
self._orig_mod = self # For dynamic compilation compatibility
def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False):
# Create a mock output object with hidden states
batch_size, seq_len = input_ids.shape
hidden_size = self._hidden_size
class MockOutput:
def __init__(self, hidden_states):
self.hidden_states = hidden_states
mock_hidden_states = [
torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device)
for _ in range(3) # Mimic multiple layers of hidden states
]
return MockOutput(mock_hidden_states)
def test_lumina_tokenize_strategy():
# Test default initialization
try:
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
return
assert tokenize_strategy.max_length == 256
assert tokenize_strategy.tokenizer.padding_side == "right"
# Test tokenization of a single string
text = "Hello"
tokens, attention_mask = tokenize_strategy.tokenize(text)
assert tokens.ndim == 2
assert attention_mask.ndim == 2
assert tokens.shape == attention_mask.shape
assert tokens.shape[1] == 256 # max_length
# Test tokenize_with_weights
tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text)
assert len(weights) == 1
assert torch.all(weights[0] == 1)
def test_lumina_text_encoding_strategy():
# Create strategies
try:
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
return
encoding_strategy = LuminaTextEncodingStrategy()
# Create a mock model
mock_model = SimpleMockGemma2Model()
# Patch the isinstance check to accept our simple mock
original_isinstance = isinstance
with patch("library.strategy_lumina.isinstance") as mock_isinstance:
def custom_isinstance(obj, class_or_tuple):
if obj is mock_model and class_or_tuple is Gemma2Model:
return True
if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
return True
return original_isinstance(obj, class_or_tuple)
mock_isinstance.side_effect = custom_isinstance
# Prepare sample text
text = "Test encoding strategy"
tokens, attention_mask = tokenize_strategy.tokenize(text)
# Perform encoding
hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
tokenize_strategy, [mock_model], (tokens, attention_mask)
)
# Validate outputs
assert original_isinstance(hidden_states, torch.Tensor)
assert original_isinstance(input_ids, torch.Tensor)
assert original_isinstance(attention_masks, torch.Tensor)
# Check the shape of the second-to-last hidden state
assert hidden_states.ndim == 3
# Test weighted encoding (which falls back to standard encoding for Lumina)
weights = [torch.ones_like(tokens)]
hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [mock_model], (tokens, attention_mask), weights
)
# For the mock, we can't guarantee identical outputs since each call returns random tensors
# Instead, check that the outputs have the same shape and are tensors
assert hidden_states_w.shape == hidden_states.shape
assert original_isinstance(hidden_states_w, torch.Tensor)
assert torch.allclose(input_ids, input_ids_w) # Input IDs should be the same
assert torch.allclose(attention_masks, attention_masks_w) # Attention masks should be the same
def test_lumina_text_encoder_outputs_caching_strategy():
# Create a temporary directory for caching
with tempfile.TemporaryDirectory() as tmpdir:
# Create a cache file path
cache_file = os.path.join(tmpdir, "test_outputs.npz")
# Create the caching strategy
caching_strategy = LuminaTextEncoderOutputsCachingStrategy(
cache_to_disk=True,
batch_size=1,
skip_disk_cache_validity_check=False,
)
# Create a mock class for ImageInfo
class MockImageInfo:
def __init__(self, caption, cache_path):
self.caption = caption
self.text_encoder_outputs_npz = cache_path
# Create a sample input info
image_info = MockImageInfo("Test caption", cache_file)
# Simulate a batch
batch = [image_info]
# Create mock strategies and model
try:
tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
except OSError as e:
# If the tokenizer is not found (due to gated repo), we can skip the test
print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
return
encoding_strategy = LuminaTextEncodingStrategy()
mock_model = SimpleMockGemma2Model()
# Patch the isinstance check to accept our simple mock
original_isinstance = isinstance
with patch("library.strategy_lumina.isinstance") as mock_isinstance:
def custom_isinstance(obj, class_or_tuple):
if obj is mock_model and class_or_tuple is Gemma2Model:
return True
if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
return True
return original_isinstance(obj, class_or_tuple)
mock_isinstance.side_effect = custom_isinstance
# Call cache_batch_outputs
caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch)
# Verify the npz file was created
assert os.path.exists(cache_file), f"Cache file not created at {cache_file}"
# Verify the is_disk_cached_outputs_expected method
assert caching_strategy.is_disk_cached_outputs_expected(cache_file)
# Test loading from npz
loaded_data = caching_strategy.load_outputs_npz(cache_file)
assert len(loaded_data) == 3 # hidden_state, input_ids, attention_mask
def test_lumina_latents_caching_strategy():
# Create a temporary directory for caching
with tempfile.TemporaryDirectory() as tmpdir:
# Prepare a mock absolute path
abs_path = os.path.join(tmpdir, "test_image.png")
# Use smaller image size for faster testing
image_size = (64, 64)
# Create a smaller dummy image for testing
test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
# Create the caching strategy
caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)
# Create a simple mock VAE
class MockVAE:
def __init__(self):
self.device = torch.device("cpu")
self.dtype = torch.float32
def encode(self, x):
# Return smaller encoded tensor for faster processing
encoded = torch.randn(1, 4, 8, 8, device=x.device)
return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded})
# Prepare a mock batch
class MockImageInfo:
def __init__(self, path, image):
self.absolute_path = path
self.image = image
self.image_path = path
self.bucket_reso = image_size
self.resized_size = image_size
self.resize_interpolation = "lanczos"
# Specify full path to the latents npz file
self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz")
batch = [MockImageInfo(abs_path, test_image)]
# Call cache_batch_latents
mock_vae = MockVAE()
caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False)
# Generate the expected npz path
npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size)
# Verify the file was created
assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}"
# Verify is_disk_cached_latents_expected
assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False)
# Test loading from disk
loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size)
assert len(loaded_data) == 5 # Check for 5 expected elements