Add lumina tests and fix image masks

This commit is contained in:
rockerBOO
2025-06-09 21:14:51 -04:00
parent 0145efc2f2
commit d94bed645a
8 changed files with 1129 additions and 267 deletions

View File

@@ -0,0 +1,295 @@
import pytest
import torch
from library.lumina_models import (
LuminaParams,
to_cuda,
to_cpu,
RopeEmbedder,
TimestepEmbedder,
modulate,
NextDiT,
)
cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_lumina_params():
# Test default configuration
default_params = LuminaParams()
assert default_params.patch_size == 2
assert default_params.in_channels == 4
assert default_params.axes_dims == [36, 36, 36]
assert default_params.axes_lens == [300, 512, 512]
# Test 2B config
config_2b = LuminaParams.get_2b_config()
assert config_2b.dim == 2304
assert config_2b.in_channels == 16
assert config_2b.n_layers == 26
assert config_2b.n_heads == 24
assert config_2b.cap_feat_dim == 2304
# Test 7B config
config_7b = LuminaParams.get_7b_config()
assert config_7b.dim == 4096
assert config_7b.n_layers == 32
assert config_7b.n_heads == 32
assert config_7b.axes_dims == [64, 64, 64]
@cuda_required
def test_to_cuda_to_cpu():
# Test tensor conversion
x = torch.tensor([1, 2, 3])
x_cuda = to_cuda(x)
x_cpu = to_cpu(x_cuda)
assert x.cpu().tolist() == x_cpu.tolist()
# Test list conversion
list_data = [torch.tensor([1]), torch.tensor([2])]
list_cuda = to_cuda(list_data)
assert all(tensor.device.type == "cuda" for tensor in list_cuda)
list_cpu = to_cpu(list_cuda)
assert all(not tensor.device.type == "cuda" for tensor in list_cpu)
# Test dict conversion
dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])}
dict_cuda = to_cuda(dict_data)
assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values())
dict_cpu = to_cpu(dict_cuda)
assert all(not tensor.device.type == "cuda" for tensor in dict_cpu.values())
def test_timestep_embedder():
# Test initialization
hidden_size = 256
freq_emb_size = 128
embedder = TimestepEmbedder(hidden_size, freq_emb_size)
assert embedder.frequency_embedding_size == freq_emb_size
# Test timestep embedding
t = torch.tensor([0.5, 1.0, 2.0])
emb_dim = freq_emb_size
embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim)
assert embeddings.shape == (3, emb_dim)
assert embeddings.dtype == torch.float32
# Ensure embeddings are unique for different input times
assert not torch.allclose(embeddings[0], embeddings[1])
# Test forward pass
t_emb = embedder(t)
assert t_emb.shape == (3, hidden_size)
def test_rope_embedder_simple():
rope_embedder = RopeEmbedder()
batch_size, seq_len = 2, 10
# Create position_ids with valid ranges for each axis
position_ids = torch.stack(
[
torch.zeros(batch_size, seq_len, dtype=torch.int64), # First axis: only 0 is valid
torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Second axis: 0-511
torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Third axis: 0-511
],
dim=-1,
)
freqs_cis = rope_embedder(position_ids)
# RoPE embeddings work in pairs, so output dimension is half of total axes_dims
expected_dim = sum(rope_embedder.axes_dims) // 2 # 128 // 2 = 64
assert freqs_cis.shape == (batch_size, seq_len, expected_dim)
def test_modulate():
# Test modulation with different scales
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
scale = torch.tensor([1.5, 2.0])
modulated_x = modulate(x, scale)
# Check that modulation scales correctly
# The function does x * (1 + scale), so:
# For scale [1.5, 2.0], (1 + scale) = [2.5, 3.0]
expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]])
# Which equals: [[2.5, 5.0], [9.0, 12.0]]
assert torch.allclose(modulated_x, expected_x)
def test_nextdit_parameter_count_optimized():
# The constraint is: (dim // n_heads) == sum(axes_dims)
# So for dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
model_small = NextDiT(
patch_size=2,
in_channels=4, # Smaller
dim=120, # 120 // 4 = 30
n_layers=2, # Much fewer layers
n_heads=4, # Fewer heads
n_kv_heads=2,
axes_dims=[10, 10, 10], # sum = 30
axes_lens=[10, 32, 32], # Smaller
)
param_count_small = model_small.parameter_count()
assert param_count_small > 0
# For dim=192, n_heads=6: 192//6 = 32, so sum(axes_dims) must = 32
model_medium = NextDiT(
patch_size=2,
in_channels=4,
dim=192, # 192 // 6 = 32
n_layers=4, # More layers
n_heads=6,
n_kv_heads=3,
axes_dims=[10, 11, 11], # sum = 32
axes_lens=[10, 32, 32],
)
param_count_medium = model_medium.parameter_count()
assert param_count_medium > param_count_small
print(f"Small model: {param_count_small:,} parameters")
print(f"Medium model: {param_count_medium:,} parameters")
@torch.no_grad()
def test_precompute_freqs_cis():
# Test precompute_freqs_cis
dim = [16, 56, 56]
end = [1, 512, 512]
theta = 10000.0
freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta)
# Check number of frequency tensors
assert len(freqs_cis) == len(dim)
# Check each frequency tensor
for i, (d, e) in enumerate(zip(dim, end)):
assert freqs_cis[i].shape == (e, d // 2)
assert freqs_cis[i].dtype == torch.complex128
@torch.no_grad()
def test_nextdit_patchify_and_embed():
"""Test the patchify_and_embed method which is crucial for training"""
# Create a small NextDiT model for testing
# The constraint is: (dim // n_heads) == sum(axes_dims)
# For dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
model = NextDiT(
patch_size=2,
in_channels=4,
dim=120, # 120 // 4 = 30
n_layers=1, # Minimal layers for faster testing
n_refiner_layers=1, # Minimal refiner layers
n_heads=4,
n_kv_heads=2,
axes_dims=[10, 10, 10], # sum = 30
axes_lens=[10, 32, 32],
cap_feat_dim=120, # Match dim for consistency
)
# Prepare test inputs
batch_size = 2
height, width = 64, 64 # Must be divisible by patch_size (2)
caption_seq_len = 8
# Create mock inputs
x = torch.randn(batch_size, 4, height, width) # Image latents
cap_feats = torch.randn(batch_size, caption_seq_len, 120) # Caption features
cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool) # All valid tokens
# Make second batch have shorter caption
cap_mask[1, 6:] = False # Only first 6 tokens are valid for second batch
t = torch.randn(batch_size, 120) # Timestep embeddings
# Call patchify_and_embed
joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
x, cap_feats, cap_mask, t
)
# Validate outputs
image_seq_len = (height // 2) * (width // 2) # patch_size = 2
expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len] # Second batch has shorter caption
max_seq_len = max(expected_seq_lengths)
# Check joint hidden states shape
assert joint_hidden_states.shape == (batch_size, max_seq_len, 120)
assert joint_hidden_states.dtype == torch.float32
# Check attention mask shape and values
assert attention_mask.shape == (batch_size, max_seq_len)
assert attention_mask.dtype == torch.bool
# First batch should have all positions valid up to its sequence length
assert torch.all(attention_mask[0, : expected_seq_lengths[0]])
assert torch.all(~attention_mask[0, expected_seq_lengths[0] :])
# Second batch should have all positions valid up to its sequence length
assert torch.all(attention_mask[1, : expected_seq_lengths[1]])
assert torch.all(~attention_mask[1, expected_seq_lengths[1] :])
# Check freqs_cis shape
assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2)
# Check effective caption lengths
assert l_effective_cap_len == [caption_seq_len, 6]
# Check sequence lengths
assert seq_lengths == expected_seq_lengths
# Validate that the joint hidden states contain non-zero values where attention mask is True
for i in range(batch_size):
valid_positions = attention_mask[i]
# Check that valid positions have meaningful data (not all zeros)
valid_data = joint_hidden_states[i][valid_positions]
assert not torch.allclose(valid_data, torch.zeros_like(valid_data))
# Check that invalid positions are zeros
if valid_positions.sum() < max_seq_len:
invalid_data = joint_hidden_states[i][~valid_positions]
assert torch.allclose(invalid_data, torch.zeros_like(invalid_data))
@torch.no_grad()
def test_nextdit_patchify_and_embed_edge_cases():
"""Test edge cases for patchify_and_embed"""
# Create minimal model
model = NextDiT(
patch_size=2,
in_channels=4,
dim=60, # 60 // 3 = 20
n_layers=1,
n_refiner_layers=1,
n_heads=3,
n_kv_heads=1,
axes_dims=[8, 6, 6], # sum = 20
axes_lens=[10, 16, 16],
cap_feat_dim=60,
)
# Test with empty captions (all masked)
batch_size = 1
height, width = 32, 32
caption_seq_len = 4
x = torch.randn(batch_size, 4, height, width)
cap_feats = torch.randn(batch_size, caption_seq_len, 60)
cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool) # All tokens masked
t = torch.randn(batch_size, 60)
joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
x, cap_feats, cap_mask, t
)
# With all captions masked, effective length should be 0
assert l_effective_cap_len == [0]
# Sequence length should just be the image sequence length
image_seq_len = (height // 2) * (width // 2)
assert seq_lengths == [image_seq_len]
# Joint hidden states should only contain image data
assert joint_hidden_states.shape == (batch_size, image_seq_len, 60)
assert attention_mask.shape == (batch_size, image_seq_len)
assert torch.all(attention_mask[0]) # All image positions should be valid

View File

@@ -0,0 +1,241 @@
import pytest
import torch
import math
from library.lumina_train_util import (
batchify,
time_shift,
get_lin_function,
get_schedule,
compute_density_for_timestep_sampling,
get_sigmas,
compute_loss_weighting_for_sd3,
get_noisy_model_input_and_timesteps,
apply_model_prediction_type,
retrieve_timesteps,
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
def test_batchify():
# Test case with no batch size specified
prompts = [
{"prompt": "test1"},
{"prompt": "test2"},
{"prompt": "test3"}
]
batchified = list(batchify(prompts))
assert len(batchified) == 1
assert len(batchified[0]) == 3
# Test case with batch size specified
batchified_sized = list(batchify(prompts, batch_size=2))
assert len(batchified_sized) == 2
assert len(batchified_sized[0]) == 2
assert len(batchified_sized[1]) == 1
# Test batching with prompts having same parameters
prompts_with_params = [
{"prompt": "test1", "width": 512, "height": 512},
{"prompt": "test2", "width": 512, "height": 512},
{"prompt": "test3", "width": 1024, "height": 1024}
]
batchified_params = list(batchify(prompts_with_params))
assert len(batchified_params) == 2
# Test invalid batch size
with pytest.raises(ValueError):
list(batchify(prompts, batch_size=0))
with pytest.raises(ValueError):
list(batchify(prompts, batch_size=-1))
def test_time_shift():
# Test standard parameters
t = torch.tensor([0.5])
mu = 1.0
sigma = 1.0
result = time_shift(mu, sigma, t)
assert 0 <= result <= 1
# Test with edge cases
t_edges = torch.tensor([0.0, 1.0])
result_edges = time_shift(1.0, 1.0, t_edges)
# Check that results are bounded within [0, 1]
assert torch.all(result_edges >= 0)
assert torch.all(result_edges <= 1)
def test_get_lin_function():
# Default parameters
func = get_lin_function()
assert func(256) == 0.5
assert func(4096) == 1.15
# Custom parameters
custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9)
assert custom_func(100) == 0.1
assert custom_func(1000) == 0.9
def test_get_schedule():
# Basic schedule
schedule = get_schedule(num_steps=10, image_seq_len=256)
assert len(schedule) == 10
assert all(0 <= x <= 1 for x in schedule)
# Test different sequence lengths
short_schedule = get_schedule(num_steps=5, image_seq_len=128)
long_schedule = get_schedule(num_steps=15, image_seq_len=1024)
assert len(short_schedule) == 5
assert len(long_schedule) == 15
# Test with shift disabled
unshifted_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
assert torch.allclose(
torch.tensor(unshifted_schedule),
torch.linspace(1, 1/10, 10)
)
def test_compute_density_for_timestep_sampling():
# Test uniform sampling
uniform_samples = compute_density_for_timestep_sampling("uniform", batch_size=100)
assert len(uniform_samples) == 100
assert torch.all((uniform_samples >= 0) & (uniform_samples <= 1))
# Test logit normal sampling
logit_normal_samples = compute_density_for_timestep_sampling(
"logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
)
assert len(logit_normal_samples) == 100
assert torch.all((logit_normal_samples >= 0) & (logit_normal_samples <= 1))
# Test mode sampling
mode_samples = compute_density_for_timestep_sampling(
"mode", batch_size=100, mode_scale=0.5
)
assert len(mode_samples) == 100
assert torch.all((mode_samples >= 0) & (mode_samples <= 1))
def test_get_sigmas():
# Create a mock noise scheduler
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
device = torch.device('cpu')
# Test with default parameters
timesteps = torch.tensor([100, 500, 900])
sigmas = get_sigmas(scheduler, timesteps, device)
# Check shape and basic properties
assert sigmas.shape[0] == 3
assert torch.all(sigmas >= 0)
# Test with different n_dim
sigmas_4d = get_sigmas(scheduler, timesteps, device, n_dim=4)
assert sigmas_4d.ndim == 4
# Test with different dtype
sigmas_float16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
assert sigmas_float16.dtype == torch.float16
def test_compute_loss_weighting_for_sd3():
# Prepare some mock sigmas
sigmas = torch.tensor([0.1, 0.5, 1.0])
# Test sigma_sqrt weighting
sqrt_weighting = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
assert torch.allclose(sqrt_weighting, 1 / (sigmas**2), rtol=1e-5)
# Test cosmap weighting
cosmap_weighting = compute_loss_weighting_for_sd3("cosmap", sigmas)
bot = 1 - 2 * sigmas + 2 * sigmas**2
expected_cosmap = 2 / (math.pi * bot)
assert torch.allclose(cosmap_weighting, expected_cosmap, rtol=1e-5)
# Test default weighting
default_weighting = compute_loss_weighting_for_sd3("unknown", sigmas)
assert torch.all(default_weighting == 1)
def test_apply_model_prediction_type():
# Create mock args and tensors
class MockArgs:
model_prediction_type = "raw"
weighting_scheme = "sigma_sqrt"
args = MockArgs()
model_pred = torch.tensor([1.0, 2.0, 3.0])
noisy_model_input = torch.tensor([0.5, 1.0, 1.5])
sigmas = torch.tensor([0.1, 0.5, 1.0])
# Test raw prediction type
raw_pred, raw_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(raw_pred == model_pred)
assert raw_weighting is None
# Test additive prediction type
args.model_prediction_type = "additive"
additive_pred, _ = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(additive_pred == model_pred + noisy_model_input)
# Test sigma scaled prediction type
args.model_prediction_type = "sigma_scaled"
sigma_scaled_pred, sigma_weighting = apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
assert torch.all(sigma_scaled_pred == model_pred * (-sigmas) + noisy_model_input)
assert sigma_weighting is not None
def test_retrieve_timesteps():
# Create a mock scheduler
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
# Test with num_inference_steps
timesteps, n_steps = retrieve_timesteps(scheduler, num_inference_steps=50)
assert len(timesteps) == 50
assert n_steps == 50
# Test error handling with simultaneous timesteps and sigmas
with pytest.raises(ValueError):
retrieve_timesteps(scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
def test_get_noisy_model_input_and_timesteps():
# Create a mock args and setup
class MockArgs:
timestep_sampling = "uniform"
weighting_scheme = "sigma_sqrt"
sigmoid_scale = 1.0
discrete_flow_shift = 6.0
args = MockArgs()
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
device = torch.device('cpu')
# Prepare mock latents and noise
latents = torch.randn(4, 16, 64, 64)
noise = torch.randn_like(latents)
# Test uniform sampling
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
args, scheduler, latents, noise, device, torch.float32
)
# Validate output shapes and types
assert noisy_input.shape == latents.shape
assert timesteps.shape[0] == latents.shape[0]
assert noisy_input.dtype == torch.float32
assert timesteps.dtype == torch.float32
# Test different sampling methods
sampling_methods = ["sigmoid", "shift", "nextdit_shift"]
for method in sampling_methods:
args.timestep_sampling = method
noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
args, scheduler, latents, noise, device, torch.float32
)
assert noisy_input.shape == latents.shape
assert timesteps.shape[0] == latents.shape[0]

View File

@@ -0,0 +1,112 @@
import torch
from torch.nn.modules import conv
from library import lumina_util
def test_unpack_latents():
# Create a test tensor
# Shape: [batch, height*width, channels*patch_height*patch_width]
x = torch.randn(2, 4, 16) # 2 batches, 4 tokens, 16 channels
packed_latent_height = 2
packed_latent_width = 2
# Unpack the latents
unpacked = lumina_util.unpack_latents(x, packed_latent_height, packed_latent_width)
# Check output shape
# Expected shape: [batch, channels, height*patch_height, width*patch_width]
assert unpacked.shape == (2, 4, 4, 4)
def test_pack_latents():
# Create a test tensor
# Shape: [batch, channels, height*patch_height, width*patch_width]
x = torch.randn(2, 4, 4, 4)
# Pack the latents
packed = lumina_util.pack_latents(x)
# Check output shape
# Expected shape: [batch, height*width, channels*patch_height*patch_width]
assert packed.shape == (2, 4, 16)
def test_convert_diffusers_sd_to_alpha_vllm():
num_double_blocks = 2
# Predefined test cases based on the actual conversion map
test_cases = [
# Static key conversions with possible list mappings
{
"original_keys": ["time_caption_embed.caption_embedder.0.weight"],
"original_pattern": ["time_caption_embed.caption_embedder.0.weight"],
"expected_converted_keys": ["cap_embedder.0.weight"],
},
{
"original_keys": ["patch_embedder.proj.weight"],
"original_pattern": ["patch_embedder.proj.weight"],
"expected_converted_keys": ["x_embedder.weight"],
},
{
"original_keys": ["transformer_blocks.0.norm1.weight"],
"original_pattern": ["transformer_blocks.().norm1.weight"],
"expected_converted_keys": ["layers.0.attention_norm1.weight"],
},
]
for test_case in test_cases:
for original_key, original_pattern, expected_converted_key in zip(
test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"]
):
# Create test state dict
test_sd = {original_key: torch.randn(10, 10)}
# Convert the state dict
converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
# Verify conversion (handle both string and list keys)
# Find the correct converted key
match_found = False
if expected_converted_key in converted_sd:
# Verify tensor preservation
assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), (
f"Tensor mismatch for {original_key}"
)
match_found = True
break
assert match_found, f"Failed to convert {original_key}"
# Ensure original key is also present
assert original_key in converted_sd
# Test with block-specific keys
block_specific_cases = [
{
"original_pattern": "transformer_blocks.().norm1.weight",
"converted_pattern": "layers.().attention_norm1.weight",
}
]
for case in block_specific_cases:
for block_idx in range(2): # Test multiple block indices
# Prepare block-specific keys
block_original_key = case["original_pattern"].replace("()", str(block_idx))
block_converted_key = case["converted_pattern"].replace("()", str(block_idx))
print(block_original_key, block_converted_key)
# Create test state dict
test_sd = {block_original_key: torch.randn(10, 10)}
# Convert the state dict
converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)
# Verify conversion
# assert block_converted_key in converted_sd, f"Failed to convert block key {block_original_key}"
assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), (
f"Tensor mismatch for block key {block_original_key}"
)
# Ensure original key is also present
assert block_original_key in converted_sd

View File

@@ -0,0 +1,227 @@
import os
import tempfile
import torch
import numpy as np
from unittest.mock import patch
from transformers import Gemma2Model
from library.strategy_lumina import (
LuminaTokenizeStrategy,
LuminaTextEncodingStrategy,
LuminaTextEncoderOutputsCachingStrategy,
LuminaLatentsCachingStrategy,
)
class SimpleMockGemma2Model:
"""Lightweight mock that avoids initializing the actual Gemma2Model"""
def __init__(self, hidden_size=2304):
self.device = torch.device("cpu")
self._hidden_size = hidden_size
self._orig_mod = self # For dynamic compilation compatibility
def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False):
# Create a mock output object with hidden states
batch_size, seq_len = input_ids.shape
hidden_size = self._hidden_size
class MockOutput:
def __init__(self, hidden_states):
self.hidden_states = hidden_states
mock_hidden_states = [
torch.randn(batch_size, seq_len, hidden_size, device=input_ids.device)
for _ in range(3) # Mimic multiple layers of hidden states
]
return MockOutput(mock_hidden_states)
def test_lumina_tokenize_strategy():
# Test default initialization
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
assert tokenize_strategy.max_length == 256
assert tokenize_strategy.tokenizer.padding_side == "right"
# Test tokenization of a single string
text = "Hello"
tokens, attention_mask = tokenize_strategy.tokenize(text)
assert tokens.ndim == 2
assert attention_mask.ndim == 2
assert tokens.shape == attention_mask.shape
assert tokens.shape[1] == 256 # max_length
# Test tokenize_with_weights
tokens, attention_mask, weights = tokenize_strategy.tokenize_with_weights(text)
assert len(weights) == 1
assert torch.all(weights[0] == 1)
def test_lumina_text_encoding_strategy():
# Create strategies
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
encoding_strategy = LuminaTextEncodingStrategy()
# Create a mock model
mock_model = SimpleMockGemma2Model()
# Patch the isinstance check to accept our simple mock
original_isinstance = isinstance
with patch("library.strategy_lumina.isinstance") as mock_isinstance:
def custom_isinstance(obj, class_or_tuple):
if obj is mock_model and class_or_tuple is Gemma2Model:
return True
if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
return True
return original_isinstance(obj, class_or_tuple)
mock_isinstance.side_effect = custom_isinstance
# Prepare sample text
text = "Test encoding strategy"
tokens, attention_mask = tokenize_strategy.tokenize(text)
# Perform encoding
hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
tokenize_strategy, [mock_model], (tokens, attention_mask)
)
# Validate outputs
assert original_isinstance(hidden_states, torch.Tensor)
assert original_isinstance(input_ids, torch.Tensor)
assert original_isinstance(attention_masks, torch.Tensor)
# Check the shape of the second-to-last hidden state
assert hidden_states.ndim == 3
# Test weighted encoding (which falls back to standard encoding for Lumina)
weights = [torch.ones_like(tokens)]
hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights(
tokenize_strategy, [mock_model], (tokens, attention_mask), weights
)
# For the mock, we can't guarantee identical outputs since each call returns random tensors
# Instead, check that the outputs have the same shape and are tensors
assert hidden_states_w.shape == hidden_states.shape
assert original_isinstance(hidden_states_w, torch.Tensor)
assert torch.allclose(input_ids, input_ids_w) # Input IDs should be the same
assert torch.allclose(attention_masks, attention_masks_w) # Attention masks should be the same
def test_lumina_text_encoder_outputs_caching_strategy():
# Create a temporary directory for caching
with tempfile.TemporaryDirectory() as tmpdir:
# Create a cache file path
cache_file = os.path.join(tmpdir, "test_outputs.npz")
# Create the caching strategy
caching_strategy = LuminaTextEncoderOutputsCachingStrategy(
cache_to_disk=True,
batch_size=1,
skip_disk_cache_validity_check=False,
)
# Create a mock class for ImageInfo
class MockImageInfo:
def __init__(self, caption, system_prompt, cache_path):
self.caption = caption
self.system_prompt = system_prompt
self.text_encoder_outputs_npz = cache_path
# Create a sample input info
image_info = MockImageInfo("Test caption", "", cache_file)
# Simulate a batch
batch = [image_info]
# Create mock strategies and model
tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
encoding_strategy = LuminaTextEncodingStrategy()
mock_model = SimpleMockGemma2Model()
# Patch the isinstance check to accept our simple mock
original_isinstance = isinstance
with patch("library.strategy_lumina.isinstance") as mock_isinstance:
def custom_isinstance(obj, class_or_tuple):
if obj is mock_model and class_or_tuple is Gemma2Model:
return True
if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
return True
return original_isinstance(obj, class_or_tuple)
mock_isinstance.side_effect = custom_isinstance
# Call cache_batch_outputs
caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch)
# Verify the npz file was created
assert os.path.exists(cache_file), f"Cache file not created at {cache_file}"
# Verify the is_disk_cached_outputs_expected method
assert caching_strategy.is_disk_cached_outputs_expected(cache_file)
# Test loading from npz
loaded_data = caching_strategy.load_outputs_npz(cache_file)
assert len(loaded_data) == 3 # hidden_state, input_ids, attention_mask
def test_lumina_latents_caching_strategy():
# Create a temporary directory for caching
with tempfile.TemporaryDirectory() as tmpdir:
# Prepare a mock absolute path
abs_path = os.path.join(tmpdir, "test_image.png")
# Use smaller image size for faster testing
image_size = (64, 64)
# Create a smaller dummy image for testing
test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
# Create the caching strategy
caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)
# Create a simple mock VAE
class MockVAE:
def __init__(self):
self.device = torch.device("cpu")
self.dtype = torch.float32
def encode(self, x):
# Return smaller encoded tensor for faster processing
encoded = torch.randn(1, 4, 8, 8, device=x.device)
return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded})
# Prepare a mock batch
class MockImageInfo:
def __init__(self, path, image):
self.absolute_path = path
self.image = image
self.image_path = path
self.bucket_reso = image_size
self.resized_size = image_size
self.resize_interpolation = "lanczos"
# Specify full path to the latents npz file
self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz")
batch = [MockImageInfo(abs_path, test_image)]
# Call cache_batch_latents
mock_vae = MockVAE()
caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False)
# Generate the expected npz path
npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size)
# Verify the file was created
assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}"
# Verify is_disk_cached_latents_expected
assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False)
# Test loading from disk
loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size)
assert len(loaded_data) == 5 # Check for 5 expected elements

View File

@@ -0,0 +1,173 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
import argparse
from library import lumina_models, lumina_util
from lumina_train_network import LuminaNetworkTrainer
@pytest.fixture
def lumina_trainer():
return LuminaNetworkTrainer()
@pytest.fixture
def mock_args():
args = MagicMock()
args.pretrained_model_name_or_path = "test_path"
args.disable_mmap_load_safetensors = False
args.use_flash_attn = False
args.use_sage_attn = False
args.fp8_base = False
args.blocks_to_swap = None
args.gemma2 = "test_gemma2_path"
args.ae = "test_ae_path"
args.cache_text_encoder_outputs = True
args.cache_text_encoder_outputs_to_disk = False
args.network_train_unet_only = False
return args
@pytest.fixture
def mock_accelerator():
accelerator = MagicMock()
accelerator.device = torch.device("cpu")
accelerator.prepare.side_effect = lambda x, **kwargs: x
accelerator.unwrap_model.side_effect = lambda x: x
return accelerator
def test_assert_extra_args(lumina_trainer, mock_args):
train_dataset_group = MagicMock()
train_dataset_group.verify_bucket_reso_steps = MagicMock()
val_dataset_group = MagicMock()
val_dataset_group.verify_bucket_reso_steps = MagicMock()
# Test with default settings
lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
# Verify verify_bucket_reso_steps was called for both groups
assert train_dataset_group.verify_bucket_reso_steps.call_count > 0
assert val_dataset_group.verify_bucket_reso_steps.call_count > 0
# Check text encoder output caching
assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only)
assert mock_args.cache_text_encoder_outputs is True
def test_load_target_model(lumina_trainer, mock_args, mock_accelerator):
# Patch lumina_util methods
with (
patch("library.lumina_util.load_lumina_model") as mock_load_lumina_model,
patch("library.lumina_util.load_gemma2") as mock_load_gemma2,
patch("library.lumina_util.load_ae") as mock_load_ae
):
# Create mock models
mock_model = MagicMock(spec=lumina_models.NextDiT)
mock_model.dtype = torch.float32
mock_gemma2 = MagicMock()
mock_ae = MagicMock()
mock_load_lumina_model.return_value = mock_model
mock_load_gemma2.return_value = mock_gemma2
mock_load_ae.return_value = mock_ae
# Test load_target_model
version, gemma2_list, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator)
# Verify calls and return values
assert version == lumina_util.MODEL_VERSION_LUMINA_V2
assert gemma2_list == [mock_gemma2]
assert ae == mock_ae
assert model == mock_model
# Verify load calls
mock_load_lumina_model.assert_called_once()
mock_load_gemma2.assert_called_once()
mock_load_ae.assert_called_once()
def test_get_strategies(lumina_trainer, mock_args):
# Test tokenize strategy
tokenize_strategy = lumina_trainer.get_tokenize_strategy(mock_args)
assert tokenize_strategy.__class__.__name__ == "LuminaTokenizeStrategy"
# Test latents caching strategy
latents_strategy = lumina_trainer.get_latents_caching_strategy(mock_args)
assert latents_strategy.__class__.__name__ == "LuminaLatentsCachingStrategy"
# Test text encoding strategy
text_encoding_strategy = lumina_trainer.get_text_encoding_strategy(mock_args)
assert text_encoding_strategy.__class__.__name__ == "LuminaTextEncodingStrategy"
def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args):
# Call assert_extra_args to set train_gemma2
train_dataset_group = MagicMock()
train_dataset_group.verify_bucket_reso_steps = MagicMock()
val_dataset_group = MagicMock()
val_dataset_group.verify_bucket_reso_steps = MagicMock()
lumina_trainer.assert_extra_args(mock_args, train_dataset_group, val_dataset_group)
# With text encoder caching enabled
mock_args.skip_cache_check = False
mock_args.text_encoder_batch_size = 16
strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
assert strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy"
assert strategy.cache_to_disk is False # based on mock_args
# With text encoder caching disabled
mock_args.cache_text_encoder_outputs = False
strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
assert strategy is None
def test_noise_scheduler(lumina_trainer, mock_args):
device = torch.device("cpu")
noise_scheduler = lumina_trainer.get_noise_scheduler(mock_args, device)
assert noise_scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
assert noise_scheduler.num_train_timesteps == 1000
assert hasattr(lumina_trainer, "noise_scheduler_copy")
def test_sai_model_spec(lumina_trainer, mock_args):
with patch("library.train_util.get_sai_model_spec") as mock_get_spec:
mock_get_spec.return_value = "test_spec"
spec = lumina_trainer.get_sai_model_spec(mock_args)
assert spec == "test_spec"
mock_get_spec.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2")
def test_update_metadata(lumina_trainer, mock_args):
metadata = {}
lumina_trainer.update_metadata(metadata, mock_args)
assert "ss_weighting_scheme" in metadata
assert "ss_logit_mean" in metadata
assert "ss_logit_std" in metadata
assert "ss_mode_scale" in metadata
assert "ss_timestep_sampling" in metadata
assert "ss_sigmoid_scale" in metadata
assert "ss_model_prediction_type" in metadata
assert "ss_discrete_flow_shift" in metadata
def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args):
# Test with text encoder output caching, but not training text encoder
mock_args.cache_text_encoder_outputs = True
with patch.object(lumina_trainer, 'is_train_text_encoder', return_value=False):
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is True
# Test with text encoder output caching and training text encoder
with patch.object(lumina_trainer, 'is_train_text_encoder', return_value=True):
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is False
# Test with no text encoder output caching
mock_args.cache_text_encoder_outputs = False
result = lumina_trainer.is_text_encoder_not_needed_for_training(mock_args)
assert result is False