mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
242 lines
9.4 KiB
Python
242 lines
9.4 KiB
Python
import os
|
|
import tempfile
|
|
import torch
|
|
import numpy as np
|
|
from unittest.mock import patch
|
|
from transformers import Gemma2Model
|
|
|
|
from library.strategy_lumina import (
|
|
LuminaTokenizeStrategy,
|
|
LuminaTextEncodingStrategy,
|
|
LuminaTextEncoderOutputsCachingStrategy,
|
|
LuminaLatentsCachingStrategy,
|
|
)
|
|
|
|
|
|
class SimpleMockGemma2Model:
    """Stand-in for Gemma2Model that skips the heavyweight HF initialization."""

    def __init__(self, hidden_size=2304):
        # Minimal attribute surface the Lumina strategies touch.
        self.device = torch.device("cpu")
        self._hidden_size = hidden_size
        # torch.compile wraps models and exposes the original via ``_orig_mod``;
        # pointing it at ourselves lets compiled-model checks succeed too.
        self._orig_mod = self

    def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False):
        """Return an output-like object whose ``hidden_states`` is a list of
        three random (batch, seq_len, hidden_size) tensors, mimicking the
        per-layer hidden states of the real encoder."""
        n_batch, n_tokens = input_ids.shape

        class _MockOutput:
            def __init__(self, states):
                self.hidden_states = states

        # Three fake "layers", allocated on the same device as the input ids.
        layers = []
        for _ in range(3):
            layers.append(
                torch.randn(n_batch, n_tokens, self._hidden_size, device=input_ids.device)
            )

        return _MockOutput(layers)
|
|
|
|
|
|
def test_lumina_tokenize_strategy():
    """Exercise LuminaTokenizeStrategy defaults and basic tokenization."""
    # The underlying tokenizer lives in a gated HF repo; skip gracefully
    # when it cannot be downloaded.
    try:
        strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
    except OSError as e:
        print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
        return

    # max_length=None should fall back to the 256-token default,
    # with right-side padding.
    assert strategy.max_length == 256
    assert strategy.tokenizer.padding_side == "right"

    # A single string tokenizes to batched (1, max_length) id/mask tensors.
    sample = "Hello"
    token_ids, mask = strategy.tokenize(sample)

    assert token_ids.ndim == 2
    assert mask.ndim == 2
    assert token_ids.shape == mask.shape
    assert token_ids.shape[1] == 256  # max_length

    # Weighted tokenization of plain text yields uniform weights of 1.
    token_ids, mask, weights = strategy.tokenize_with_weights(sample)
    assert len(weights) == 1
    assert torch.all(weights[0] == 1)
|
|
|
|
|
|
def test_lumina_text_encoding_strategy():
    """Encode tokens through LuminaTextEncodingStrategy using a mock Gemma2 model."""
    # The tokenizer lives in a gated HF repo; skip gracefully when unavailable.
    try:
        tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
    except OSError as e:
        # If the tokenizer is not found (due to gated repo), we can skip the test
        print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
        return
    encoding_strategy = LuminaTextEncodingStrategy()

    # Create a mock model
    mock_model = SimpleMockGemma2Model()

    # The strategy type-checks its models with isinstance(model, Gemma2Model);
    # patch the name in the strategy module so our lightweight mock passes.
    # create=True is required: the module does not define its own `isinstance`
    # attribute (it is a builtin), and module attribute lookup does not fall
    # back to builtins, so patch() would otherwise raise AttributeError.
    original_isinstance = isinstance
    with patch("library.strategy_lumina.isinstance", create=True) as mock_isinstance:

        def custom_isinstance(obj, class_or_tuple):
            # Accept the mock itself, or a compiled wrapper exposing it
            # via _orig_mod, as a Gemma2Model.
            if obj is mock_model and class_or_tuple is Gemma2Model:
                return True
            if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
                return True
            return original_isinstance(obj, class_or_tuple)

        mock_isinstance.side_effect = custom_isinstance

        # Prepare sample text
        text = "Test encoding strategy"
        tokens, attention_mask = tokenize_strategy.tokenize(text)

        # Perform encoding
        hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
            tokenize_strategy, [mock_model], (tokens, attention_mask)
        )

        # Validate outputs
        assert original_isinstance(hidden_states, torch.Tensor)
        assert original_isinstance(input_ids, torch.Tensor)
        assert original_isinstance(attention_masks, torch.Tensor)

        # Check the shape of the second-to-last hidden state
        assert hidden_states.ndim == 3

        # Test weighted encoding (which falls back to standard encoding for Lumina)
        weights = [torch.ones_like(tokens)]
        hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights(
            tokenize_strategy, [mock_model], (tokens, attention_mask), weights
        )

        # The mock returns fresh random tensors on every call, so only shapes
        # and the pass-through ids/masks can be compared exactly.
        assert hidden_states_w.shape == hidden_states.shape
        assert original_isinstance(hidden_states_w, torch.Tensor)
        assert torch.allclose(input_ids, input_ids_w)  # Input IDs should be the same
        assert torch.allclose(attention_masks, attention_masks_w)  # Attention masks should be the same
|
|
|
|
|
|
def test_lumina_text_encoder_outputs_caching_strategy():
    """Cache text-encoder outputs to disk via the caching strategy and read them back."""
    # Create a temporary directory for caching
    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a cache file path
        cache_file = os.path.join(tmpdir, "test_outputs.npz")

        # Create the caching strategy
        caching_strategy = LuminaTextEncoderOutputsCachingStrategy(
            cache_to_disk=True,
            batch_size=1,
            skip_disk_cache_validity_check=False,
        )

        # Minimal stand-in for the trainer's ImageInfo record: only the two
        # fields the caching strategy reads.
        class MockImageInfo:
            def __init__(self, caption, cache_path):
                self.caption = caption
                self.text_encoder_outputs_npz = cache_path

        # Create a sample input info and simulate a batch of one.
        image_info = MockImageInfo("Test caption", cache_file)
        batch = [image_info]

        # The tokenizer lives in a gated HF repo; skip gracefully when unavailable.
        try:
            tokenize_strategy = LuminaTokenizeStrategy("dummy system prompt", max_length=None)
        except OSError as e:
            # If the tokenizer is not found (due to gated repo), we can skip the test
            print(f"Skipping LuminaTokenizeStrategy test due to OSError: {e}")
            return
        encoding_strategy = LuminaTextEncodingStrategy()
        mock_model = SimpleMockGemma2Model()

        # Patch the isinstance check so the mock passes as a Gemma2Model.
        # create=True is required: the module has no `isinstance` attribute of
        # its own (it is a builtin), so patch() would raise AttributeError
        # without it.
        original_isinstance = isinstance
        with patch("library.strategy_lumina.isinstance", create=True) as mock_isinstance:

            def custom_isinstance(obj, class_or_tuple):
                if obj is mock_model and class_or_tuple is Gemma2Model:
                    return True
                if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
                    return True
                return original_isinstance(obj, class_or_tuple)

            mock_isinstance.side_effect = custom_isinstance

            # Call cache_batch_outputs
            caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch)

            # Verify the npz file was created
            assert os.path.exists(cache_file), f"Cache file not created at {cache_file}"

            # Verify the is_disk_cached_outputs_expected method
            assert caching_strategy.is_disk_cached_outputs_expected(cache_file)

            # Test loading from npz
            loaded_data = caching_strategy.load_outputs_npz(cache_file)
            assert len(loaded_data) == 3  # hidden_state, input_ids, attention_mask
|
|
|
|
|
|
def test_lumina_latents_caching_strategy():
    """Cache VAE latents for a dummy image and verify the disk round-trip."""
    with tempfile.TemporaryDirectory() as tmpdir:
        # Mock absolute path of the "image" being cached.
        abs_path = os.path.join(tmpdir, "test_image.png")

        # Small 64x64 dummy image keeps the test fast.
        image_size = (64, 64)
        test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)

        caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)

        class MockVAE:
            """Fake VAE exposing only device/dtype/encode."""

            def __init__(self):
                self.device = torch.device("cpu")
                self.dtype = torch.float32

            def encode(self, x):
                # Small random latents for speed; the caller only needs an
                # object whose .to(...) yields the latents tensor.
                latents = torch.randn(1, 4, 8, 8, device=x.device)

                class EncodedLatents:
                    to = staticmethod(lambda *args, **kwargs: latents)

                return EncodedLatents

        class MockImageInfo:
            """Fake ImageInfo carrying only the fields the strategy reads."""

            def __init__(self, path, image):
                self.absolute_path = path
                self.image = image
                self.image_path = path
                self.bucket_reso = image_size
                self.resized_size = image_size
                self.resize_interpolation = "lanczos"
                # Full path the strategy will write the latents npz to.
                self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz")

        batch = [MockImageInfo(abs_path, test_image)]

        # Run the caching pass with a mock VAE.
        vae = MockVAE()
        caching_strategy.cache_batch_latents(vae, batch, flip_aug=False, alpha_mask=False, random_crop=False)

        # The strategy-derived npz path must now exist on disk.
        npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size)
        assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}"

        # Validity check should accept the freshly written cache.
        assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False)

        # Round-trip: loading yields the 5 expected elements.
        loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size)
        assert len(loaded_data) == 5  # Check for 5 expected elements
|