# Source: Kohya-ss-sd-scripts/tests/library/test_lumina_train_util.py
# (original listing: 229 lines, 7.8 KiB, Python)

import pytest
import torch
import math
from library.lumina_train_util import (
batchify,
time_shift,
get_lin_function,
get_schedule,
compute_density_for_timestep_sampling,
get_sigmas,
compute_loss_weighting_for_sd3,
get_noisy_model_input_and_timesteps,
apply_model_prediction_type,
retrieve_timesteps,
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
def test_batchify():
    """batchify groups prompts into batches and rejects non-positive batch sizes."""
    prompts = [{"prompt": f"test{i}"} for i in (1, 2, 3)]

    # No batch size: all three parameter-less prompts land in a single batch.
    single = list(batchify(prompts))
    assert len(single) == 1
    assert len(single[0]) == 3

    # batch_size=2 splits the three prompts into a full batch plus a remainder.
    sized = list(batchify(prompts, batch_size=2))
    assert len(sized) == 2
    assert [len(batch) for batch in sized] == [2, 1]

    # Prompts that differ in width/height must end up in separate groups.
    mixed = [
        {"prompt": "test1", "width": 512, "height": 512},
        {"prompt": "test2", "width": 512, "height": 512},
        {"prompt": "test3", "width": 1024, "height": 1024},
    ]
    assert len(list(batchify(mixed))) == 2

    # Zero and negative batch sizes are invalid. batchify is consumed with
    # list() so the error surfaces even if it is raised lazily.
    for bad_size in (0, -1):
        with pytest.raises(ValueError):
            list(batchify(prompts, batch_size=bad_size))


def test_time_shift():
    """time_shift keeps timesteps inside [0, 1], including at the endpoints."""
    # Interior point with mu = sigma = 1.0.
    shifted = time_shift(1.0, 1.0, torch.tensor([0.5]))
    assert torch.all((shifted >= 0) & (shifted <= 1))

    # The endpoints 0 and 1 must not escape the unit interval.
    shifted_edges = time_shift(1.0, 1.0, torch.tensor([0.0, 1.0]))
    assert torch.all(shifted_edges >= 0)
    assert torch.all(shifted_edges <= 1)


def test_get_lin_function():
    """get_lin_function returns a line through its two anchor points.

    Values are compared with math.isclose instead of ``==``. The returned
    callable does floating-point arithmetic, so it may not reproduce an anchor
    literal such as 1.15 bit-for-bit, and exact equality would make the test
    fragile.
    """
    # Default anchors: (256, 0.5) and (4096, 1.15).
    func = get_lin_function()
    assert math.isclose(func(256), 0.5)
    assert math.isclose(func(4096), 1.15)

    # Custom anchors. The midpoint check confirms the interpolation is linear,
    # not merely correct at the two endpoints.
    custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9)
    assert math.isclose(custom_func(100), 0.1)
    assert math.isclose(custom_func(1000), 0.9)
    assert math.isclose(custom_func(550), 0.5)


def test_get_schedule():
    """get_schedule yields num_steps values in [0, 1]; shift=False is a plain linspace."""
    base = get_schedule(num_steps=10, image_seq_len=256)
    assert len(base) == 10
    assert all(0 <= value <= 1 for value in base)

    # The requested step count is honoured for other sequence lengths too.
    for steps, seq_len in ((5, 128), (15, 1024)):
        assert len(get_schedule(num_steps=steps, image_seq_len=seq_len)) == steps

    # Without shifting, the schedule descends linearly from 1 to 1/num_steps.
    plain = get_schedule(num_steps=10, image_seq_len=256, shift=False)
    expected = torch.linspace(1, 1 / 10, 10)
    assert torch.allclose(torch.tensor(plain), expected)


def test_compute_density_for_timestep_sampling():
    """Each sampling scheme returns batch_size samples that lie in [0, 1]."""
    cases = [
        ("uniform", {}),
        ("logit_normal", {"logit_mean": 0.0, "logit_std": 1.0}),
        ("mode", {"mode_scale": 0.5}),
    ]
    for scheme, extra_kwargs in cases:
        samples = compute_density_for_timestep_sampling(scheme, batch_size=100, **extra_kwargs)
        # The scheme name is the assertion message, so a failure names its case.
        assert len(samples) == 100, scheme
        assert torch.all((samples >= 0) & (samples <= 1)), scheme


def test_get_sigmas():
    """get_sigmas returns one sigma per timestep and honours n_dim and dtype."""
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    device = torch.device("cpu")
    timesteps = torch.tensor([100, 500, 900])

    # Default call: the leading axis indexes the requested timesteps and
    # every sigma is non-negative.
    sigmas = get_sigmas(scheduler, timesteps, device)
    assert sigmas.shape[0] == 3
    assert torch.all(sigmas >= 0)

    # n_dim sets the rank of the result. Two different ranks are exercised so
    # the test does not depend on whatever value the library uses as its
    # default (a single n_dim=4 call may coincide with it).
    by_rank = {}
    for n_dim in (2, 4):
        shaped = get_sigmas(scheduler, timesteps, device, n_dim=n_dim)
        assert shaped.ndim == n_dim
        assert shaped.shape[0] == 3
        by_rank[n_dim] = shaped
    # Changing the rank must not change the sigma values themselves.
    assert torch.equal(by_rank[2].flatten(), by_rank[4].flatten())

    # The requested dtype is applied to the returned tensor.
    sigmas_fp16 = get_sigmas(scheduler, timesteps, device, dtype=torch.float16)
    assert sigmas_fp16.dtype == torch.float16


def test_compute_loss_weighting_for_sd3():
    """Check the sigma_sqrt, cosmap and fallback loss-weighting schemes."""
    sigmas = torch.tensor([0.1, 0.5, 1.0])

    # "sigma_sqrt" weights each sample by 1 / sigma**2.
    expected_sqrt = 1 / sigmas.pow(2)
    assert torch.allclose(compute_loss_weighting_for_sd3("sigma_sqrt", sigmas), expected_sqrt, rtol=1e-5)

    # "cosmap" weights by 2 / (pi * (1 - 2*sigma + 2*sigma**2)).
    denominator = 1 - 2 * sigmas + 2 * sigmas**2
    expected_cosmap = 2 / (math.pi * denominator)
    assert torch.allclose(compute_loss_weighting_for_sd3("cosmap", sigmas), expected_cosmap, rtol=1e-5)

    # Any unrecognised scheme falls back to uniform weights of 1.
    fallback = compute_loss_weighting_for_sd3("unknown", sigmas)
    assert (fallback == 1).all()


def test_apply_model_prediction_type():
    """Exercise the raw, additive and sigma_scaled model-prediction transforms."""

    class _Args:
        # Minimal stand-in for the parsed training arguments the function reads.
        model_prediction_type = "raw"
        weighting_scheme = "sigma_sqrt"

    args = _Args()
    pred = torch.tensor([1.0, 2.0, 3.0])
    noisy = torch.tensor([0.5, 1.0, 1.5])
    sigmas = torch.tensor([0.1, 0.5, 1.0])

    # "raw": the prediction passes through unchanged and no weighting is returned.
    out, weighting = apply_model_prediction_type(args, pred, noisy, sigmas)
    assert torch.all(out == pred)
    assert weighting is None

    # "additive": the noisy input is added onto the prediction.
    args.model_prediction_type = "additive"
    out, _ = apply_model_prediction_type(args, pred, noisy, sigmas)
    assert torch.all(out == pred + noisy)

    # "sigma_scaled": prediction * (-sigma) + noisy input, plus a loss weighting.
    args.model_prediction_type = "sigma_scaled"
    out, weighting = apply_model_prediction_type(args, pred, noisy, sigmas)
    assert torch.all(out == pred * (-sigmas) + noisy)
    assert weighting is not None


def test_retrieve_timesteps():
    """retrieve_timesteps honours num_inference_steps and rejects timesteps + sigmas together."""
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

    steps, count = retrieve_timesteps(scheduler, num_inference_steps=50)
    assert len(steps) == 50
    assert count == 50

    # Supplying both custom timesteps and custom sigmas is an error.
    custom_timesteps = [1, 2, 3]
    custom_sigmas = [0.1, 0.2, 0.3]
    with pytest.raises(ValueError):
        retrieve_timesteps(scheduler, timesteps=custom_timesteps, sigmas=custom_sigmas)


def test_get_noisy_model_input_and_timesteps():
    """Noisy inputs keep the latent shape and get one timestep per sample for every sampler.

    The mock stands in for the parsed training arguments. ``ip_noise_gamma`` is
    a numeric noise strength and ``ip_noise_gamma_random_strength`` is an on/off
    switch. The earlier mock had their types swapped (``True`` / ``0.01``), which
    only worked through bool/number coercion.
    NOTE(review): types follow the sd-scripts ``--ip_noise_gamma`` (float) and
    ``--ip_noise_gamma_random_strength`` (store_true) options; confirm.
    """

    class MockArgs:
        timestep_sampling = "uniform"
        weighting_scheme = "sigma_sqrt"
        sigmoid_scale = 1.0
        discrete_flow_shift = 6.0
        ip_noise_gamma = 0.1  # noise strength (float)
        ip_noise_gamma_random_strength = True  # boolean switch

    args = MockArgs()
    scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    device = torch.device("cpu")

    latents = torch.randn(4, 16, 64, 64)
    noise = torch.randn_like(latents)

    # Uniform sampling: also verify the requested dtype is applied.
    noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
        args, scheduler, latents, noise, device, torch.float32
    )
    assert noisy_input.shape == latents.shape
    assert timesteps.shape[0] == latents.shape[0]
    assert noisy_input.dtype == torch.float32
    assert timesteps.dtype == torch.float32

    # Every other sampler must preserve shapes as well. The method name is the
    # assertion message so a failure identifies the sampler.
    for method in ("sigmoid", "shift", "nextdit_shift"):
        args.timestep_sampling = method
        noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
            args, scheduler, latents, noise, device, torch.float32
        )
        assert noisy_input.shape == latents.shape, method
        assert timesteps.shape[0] == latents.shape[0], method