# Mirror of https://github.com/kohya-ss/sd-scripts.git
# Synced 2026-04-10 15:00:23 +00:00
# 296 lines, 10 KiB, Python
import pytest
|
|
import torch
|
|
|
|
from library.lumina_models import (
|
|
LuminaParams,
|
|
to_cuda,
|
|
to_cpu,
|
|
RopeEmbedder,
|
|
TimestepEmbedder,
|
|
modulate,
|
|
NextDiT,
|
|
)
|
|
|
|
# Marker for tests that need a GPU: the decorated test is skipped whenever
# torch reports that CUDA is unavailable.
cuda_required = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA not available",
)
|
|
|
|
|
|
def test_lumina_params():
    """Check LuminaParams defaults and the 2B / 7B preset configurations."""
    # Each entry pairs a params object with the attribute values it must expose.
    # Objects are built lazily, in the same order as the checks run.
    cases = [
        (
            LuminaParams,
            {
                "patch_size": 2,
                "in_channels": 4,
                "axes_dims": [36, 36, 36],
                "axes_lens": [300, 512, 512],
            },
        ),
        (
            LuminaParams.get_2b_config,
            {
                "dim": 2304,
                "in_channels": 16,
                "n_layers": 26,
                "n_heads": 24,
                "cap_feat_dim": 2304,
            },
        ),
        (
            LuminaParams.get_7b_config,
            {
                "dim": 4096,
                "n_layers": 32,
                "n_heads": 32,
                "axes_dims": [64, 64, 64],
            },
        ),
    ]

    for factory, expected_attrs in cases:
        params = factory()
        for attr_name, expected_value in expected_attrs.items():
            assert getattr(params, attr_name) == expected_value
|
|
|
|
|
|
@cuda_required
def test_to_cuda_to_cpu():
    """Round-trip a tensor, a list and a dict through to_cuda / to_cpu."""
    # Single tensor: the values must survive the CUDA -> CPU round trip.
    source = torch.tensor([1, 2, 3])
    round_tripped = to_cpu(to_cuda(source))
    assert source.cpu().tolist() == round_tripped.tolist()

    # List of tensors: every element moves to CUDA, then back off it.
    gpu_list = to_cuda([torch.tensor([1]), torch.tensor([2])])
    for item in gpu_list:
        assert item.device.type == "cuda"

    cpu_list = to_cpu(gpu_list)
    for item in cpu_list:
        assert item.device.type != "cuda"

    # Dict of tensors: same expectation, applied to the values.
    gpu_dict = to_cuda({"a": torch.tensor([1]), "b": torch.tensor([2])})
    for item in gpu_dict.values():
        assert item.device.type == "cuda"

    cpu_dict = to_cpu(gpu_dict)
    for item in cpu_dict.values():
        assert item.device.type != "cuda"
|
|
|
|
|
|
def test_timestep_embedder():
    """Exercise TimestepEmbedder construction, timestep_embedding and forward."""
    hidden, freq_size = 256, 128

    # Construction stores the requested frequency embedding size.
    module = TimestepEmbedder(hidden, freq_size)
    assert module.frequency_embedding_size == freq_size

    # Class-level embedding call: one float32 row of width freq_size per timestep.
    timesteps = torch.tensor([0.5, 1.0, 2.0])
    raw = TimestepEmbedder.timestep_embedding(timesteps, freq_size)
    assert raw.shape == (3, freq_size)
    assert raw.dtype == torch.float32

    # Different timesteps must not collapse to the same embedding.
    assert not torch.allclose(raw[0], raw[1])

    # Forward pass maps each timestep to a vector of the hidden size.
    projected = module(timesteps)
    assert projected.shape == (3, hidden)
|
|
|
|
|
|
def test_rope_embedder_simple():
    """Check the output shape of a default RopeEmbedder on random position ids."""
    embedder = RopeEmbedder()
    n_batch, n_tokens = 2, 10
    grid = (n_batch, n_tokens)

    # One index tensor per axis. The first axis is pinned to 0; the other two
    # draw from [0, 512). The ranges are assumed to match the embedder's
    # default axes_lens, which are not visible in this file.
    axis0 = torch.zeros(grid, dtype=torch.int64)
    axis1 = torch.randint(0, 512, grid, dtype=torch.int64)
    axis2 = torch.randint(0, 512, grid, dtype=torch.int64)
    position_ids = torch.stack([axis0, axis1, axis2], dim=-1)

    freqs_cis = embedder(position_ids)

    # The expected last dimension is half of the summed per-axis dims.
    half_total = sum(embedder.axes_dims) // 2
    assert freqs_cis.shape == (n_batch, n_tokens, half_total)
|
|
|
|
|
|
def test_modulate():
    """Check modulate() against hand-computed values for a 2x2 input."""
    inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    scales = torch.tensor([1.5, 2.0])

    result = modulate(inputs, scales)

    # Expected values assume modulate multiplies row i by (1 + scale[i]):
    #   row 0 * 2.5 -> [2.5, 5.0]
    #   row 1 * 3.0 -> [9.0, 12.0]
    expected = torch.tensor([[2.5, 5.0], [9.0, 12.0]])

    assert torch.allclose(result, expected)
|
|
|
|
|
|
def test_nextdit_parameter_count_optimized():
    """Compare parameter_count() of two small NextDiT models of different size."""

    def build(**size_kwargs):
        # Settings shared by both models; a fresh axes_lens list per call so
        # the two models never share a mutable argument.
        return NextDiT(patch_size=2, in_channels=4, axes_lens=[10, 32, 32], **size_kwargs)

    # Per the original test's note, (dim // n_heads) must equal sum(axes_dims).
    # Small model: 120 // 4 = 30 and 10 + 10 + 10 = 30.
    model_small = build(dim=120, n_layers=2, n_heads=4, n_kv_heads=2, axes_dims=[10, 10, 10])
    param_count_small = model_small.parameter_count()
    assert param_count_small > 0

    # Medium model: 192 // 6 = 32 and 10 + 11 + 11 = 32, with more layers.
    model_medium = build(dim=192, n_layers=4, n_heads=6, n_kv_heads=3, axes_dims=[10, 11, 11])
    param_count_medium = model_medium.parameter_count()
    assert param_count_medium > param_count_small

    print(f"Small model: {param_count_small:,} parameters")
    print(f"Medium model: {param_count_medium:,} parameters")
|
|
|
|
|
|
@torch.no_grad()
def test_precompute_freqs_cis():
    """Check count, shape and dtype of the tables from precompute_freqs_cis."""
    axis_dims = [16, 56, 56]
    axis_ends = [1, 512, 512]
    rope_theta = 10000.0

    tables = NextDiT.precompute_freqs_cis(axis_dims, axis_ends, rope_theta)

    # One table per axis.
    assert len(tables) == len(axis_dims)

    # Each table has `end` rows and dim // 2 complex128 columns.
    for table, width, length in zip(tables, axis_dims, axis_ends):
        assert table.shape == (length, width // 2)
        assert table.dtype == torch.complex128
|
|
|
|
|
|
@torch.no_grad()
def test_nextdit_patchify_and_embed():
    """Validate patchify_and_embed outputs for a batch with unequal caption lengths."""
    patch = 2
    hidden_dim = 120  # 120 // 4 heads = 30, matching sum(axes_dims) below

    # Deliberately tiny model so the test stays fast.
    model = NextDiT(
        patch_size=patch,
        in_channels=4,
        dim=hidden_dim,
        n_layers=1,
        n_refiner_layers=1,
        n_heads=4,
        n_kv_heads=2,
        axes_dims=[10, 10, 10],  # sum = 30
        axes_lens=[10, 32, 32],
        cap_feat_dim=hidden_dim,
    )

    n_batch = 2
    height = width = 64  # divisible by the patch size
    full_cap_len = 8
    short_cap_len = 6

    # Mock inputs: image latents, caption features, caption mask, timestep embedding.
    x = torch.randn(n_batch, 4, height, width)
    cap_feats = torch.randn(n_batch, full_cap_len, hidden_dim)
    cap_mask = torch.ones(n_batch, full_cap_len, dtype=torch.bool)
    cap_mask[1, short_cap_len:] = False  # second sample keeps only its first 6 tokens
    t = torch.randn(n_batch, hidden_dim)

    outputs = model.patchify_and_embed(x, cap_feats, cap_mask, t)
    joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = outputs

    # Expected sequence layout: caption tokens plus one token per image patch.
    n_image_tokens = (height // patch) * (width // patch)
    expected_seq_lengths = [full_cap_len + n_image_tokens, short_cap_len + n_image_tokens]
    longest = max(expected_seq_lengths)

    # Joint hidden states
    assert joint_hidden_states.shape == (n_batch, longest, hidden_dim)
    assert joint_hidden_states.dtype == torch.float32

    # Attention mask: True up to each sample's length, False afterwards.
    assert attention_mask.shape == (n_batch, longest)
    assert attention_mask.dtype == torch.bool
    for sample_idx, n_valid in enumerate(expected_seq_lengths):
        assert torch.all(attention_mask[sample_idx, :n_valid])
        assert torch.all(~attention_mask[sample_idx, n_valid:])

    # Rotary frequencies
    assert freqs_cis.shape == (n_batch, longest, sum(model.axes_dims) // 2)

    # Reported lengths
    assert l_effective_cap_len == [full_cap_len, short_cap_len]
    assert seq_lengths == expected_seq_lengths

    # Valid positions must carry data; padded positions must be exactly zero.
    for sample_states, sample_mask in zip(joint_hidden_states, attention_mask):
        kept = sample_states[sample_mask]
        assert not torch.allclose(kept, torch.zeros_like(kept))

        if sample_mask.sum() < longest:
            padding = sample_states[~sample_mask]
            assert torch.allclose(padding, torch.zeros_like(padding))
|
|
|
|
|
|
@torch.no_grad()
def test_nextdit_patchify_and_embed_edge_cases():
    """Check patchify_and_embed when every caption token is masked out."""
    patch = 2
    hidden_dim = 60  # 60 // 3 heads = 20, matching sum(axes_dims) below

    # Minimal model
    model = NextDiT(
        patch_size=patch,
        in_channels=4,
        dim=hidden_dim,
        n_layers=1,
        n_refiner_layers=1,
        n_heads=3,
        n_kv_heads=1,
        axes_dims=[8, 6, 6],  # sum = 20
        axes_lens=[10, 16, 16],
        cap_feat_dim=hidden_dim,
    )

    n_batch = 1
    height = width = 32
    cap_len = 4

    x = torch.randn(n_batch, 4, height, width)
    cap_feats = torch.randn(n_batch, cap_len, hidden_dim)
    cap_mask = torch.zeros(n_batch, cap_len, dtype=torch.bool)  # nothing is valid
    t = torch.randn(n_batch, hidden_dim)

    outputs = model.patchify_and_embed(x, cap_feats, cap_mask, t)
    joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = outputs

    # No caption token survives the mask.
    assert l_effective_cap_len == [0]

    # The sequence therefore consists of the image patches only.
    n_image_tokens = (height // patch) * (width // patch)
    assert seq_lengths == [n_image_tokens]

    assert joint_hidden_states.shape == (n_batch, n_image_tokens, hidden_dim)
    assert attention_mask.shape == (n_batch, n_image_tokens)
    assert torch.all(attention_mask[0])  # every image position is attended
|