# Tests for library.lumina_train_util.
# Source mirrored from https://github.com/kohya-ss/sd-scripts.git
import pytest
import torch
import math

from library.lumina_train_util import (
    batchify,
    time_shift,
    get_lin_function,
    get_schedule,
    compute_density_for_timestep_sampling,
    get_sigmas,
    compute_loss_weighting_for_sd3,
    get_noisy_model_input_and_timesteps,
    apply_model_prediction_type,
    retrieve_timesteps,
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler


def test_batchify():
    """batchify groups prompt dicts into batches, splitting on differing params."""
    prompts = [{"prompt": "test1"}, {"prompt": "test2"}, {"prompt": "test3"}]

    # Without an explicit batch size everything lands in one batch.
    batches = list(batchify(prompts))
    assert len(batches) == 1
    assert len(batches[0]) == 3

    # An explicit batch size chunks the prompts into groups of that size.
    sized = list(batchify(prompts, batch_size=2))
    assert [len(batch) for batch in sized] == [2, 1]

    # Prompts with differing generation parameters must not share a batch.
    mixed_params = [
        {"prompt": "test1", "width": 512, "height": 512},
        {"prompt": "test2", "width": 512, "height": 512},
        {"prompt": "test3", "width": 1024, "height": 1024},
    ]
    assert len(list(batchify(mixed_params))) == 2

    # Non-positive batch sizes are rejected.
    for bad_size in (0, -1):
        with pytest.raises(ValueError):
            list(batchify(prompts, batch_size=bad_size))
def test_time_shift():
    """time_shift keeps timesteps from [0, 1] inside [0, 1] after shifting."""
    mu, sigma = 1.0, 1.0

    # A single interior timestep stays inside the unit interval.
    shifted = time_shift(mu, sigma, torch.tensor([0.5]))
    assert 0 <= shifted <= 1

    # The boundary timesteps 0 and 1 must also remain bounded.
    shifted_edges = time_shift(mu, sigma, torch.tensor([0.0, 1.0]))
    assert torch.all(shifted_edges >= 0)
    assert torch.all(shifted_edges <= 1)
def test_get_lin_function():
    """get_lin_function returns a linear map passing through both anchor points.

    Fix: the original asserted exact float equality (``func(256) == 0.5``) on
    interpolated values, which is fragile under floating-point rounding; use
    ``math.isclose`` instead.  Also adds a midpoint check to verify linearity.
    """
    # Default anchors: (256, 0.5) and (4096, 1.15).
    func = get_lin_function()
    assert math.isclose(func(256), 0.5, rel_tol=1e-9)
    assert math.isclose(func(4096), 1.15, rel_tol=1e-9)

    # Custom anchors are reproduced the same way.
    custom_func = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9)
    assert math.isclose(custom_func(100), 0.1, rel_tol=1e-9)
    assert math.isclose(custom_func(1000), 0.9, rel_tol=1e-9)

    # Midpoint of the custom line: linear interpolation must hold.
    assert math.isclose(custom_func(550), 0.5, rel_tol=1e-9)
def test_get_schedule():
    """get_schedule yields a timestep schedule of the requested length in [0, 1]."""
    # Basic schedule: correct length and bounded values.
    schedule = get_schedule(num_steps=10, image_seq_len=256)
    assert len(schedule) == 10
    assert all(0 <= t <= 1 for t in schedule)

    # Length tracks num_steps regardless of the image sequence length.
    for steps, seq_len in ((5, 128), (15, 1024)):
        assert len(get_schedule(num_steps=steps, image_seq_len=seq_len)) == steps

    # With shifting disabled the schedule is a plain linspace from 1 to 1/steps.
    raw_schedule = get_schedule(num_steps=10, image_seq_len=256, shift=False)
    assert torch.allclose(torch.tensor(raw_schedule), torch.linspace(1, 1 / 10, 10))
def test_compute_density_for_timestep_sampling():
    """Every sampling scheme yields batch_size samples inside [0, 1]."""

    def check_samples(samples):
        # Shared invariant: one value per batch element, all within [0, 1].
        assert len(samples) == 100
        assert torch.all((samples >= 0) & (samples <= 1))

    # Uniform sampling.
    check_samples(compute_density_for_timestep_sampling("uniform", batch_size=100))

    # Logit-normal sampling with explicit mean/std.
    check_samples(
        compute_density_for_timestep_sampling(
            "logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0
        )
    )

    # Mode sampling with a mode scale.
    check_samples(compute_density_for_timestep_sampling("mode", batch_size=100, mode_scale=0.5))
def test_get_sigmas():
    """get_sigmas maps timesteps to non-negative sigmas with the requested shape/dtype."""
    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    cpu = torch.device("cpu")
    steps = torch.tensor([100, 500, 900])

    # Defaults: one sigma per timestep, all non-negative.
    sigmas = get_sigmas(noise_scheduler, steps, cpu)
    assert sigmas.shape[0] == 3
    assert torch.all(sigmas >= 0)

    # n_dim controls the number of (broadcastable) result dimensions.
    assert get_sigmas(noise_scheduler, steps, cpu, n_dim=4).ndim == 4

    # The requested dtype is honored.
    assert get_sigmas(noise_scheduler, steps, cpu, dtype=torch.float16).dtype == torch.float16
def test_compute_loss_weighting_for_sd3():
    """Loss weighting matches the closed-form expression for each scheme."""
    sigmas = torch.tensor([0.1, 0.5, 1.0])

    # "sigma_sqrt": weight is sigma^-2.
    sqrt_weights = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
    assert torch.allclose(sqrt_weights, 1 / (sigmas**2), rtol=1e-5)

    # "cosmap": weight is 2 / (pi * (1 - 2*sigma + 2*sigma^2)).
    denom = 1 - 2 * sigmas + 2 * sigmas**2
    cosmap_weights = compute_loss_weighting_for_sd3("cosmap", sigmas)
    assert torch.allclose(cosmap_weights, 2 / (math.pi * denom), rtol=1e-5)

    # Any unrecognized scheme falls back to uniform weighting of 1.
    fallback_weights = compute_loss_weighting_for_sd3("unknown", sigmas)
    assert torch.all(fallback_weights == 1)
def test_apply_model_prediction_type():
    """apply_model_prediction_type transforms predictions per args.model_prediction_type."""

    class MockArgs:
        model_prediction_type = "raw"
        weighting_scheme = "sigma_sqrt"

    args = MockArgs()
    prediction = torch.tensor([1.0, 2.0, 3.0])
    noisy_input = torch.tensor([0.5, 1.0, 1.5])
    sigmas = torch.tensor([0.1, 0.5, 1.0])

    # "raw": prediction passes through untouched and no weighting is returned.
    out, weighting = apply_model_prediction_type(args, prediction, noisy_input, sigmas)
    assert torch.all(out == prediction)
    assert weighting is None

    # "additive": the noisy input is added back onto the prediction.
    args.model_prediction_type = "additive"
    out, _ = apply_model_prediction_type(args, prediction, noisy_input, sigmas)
    assert torch.all(out == prediction + noisy_input)

    # "sigma_scaled": prediction is scaled by -sigma then offset by the noisy
    # input, and a weighting tensor is produced.
    args.model_prediction_type = "sigma_scaled"
    out, weighting = apply_model_prediction_type(args, prediction, noisy_input, sigmas)
    assert torch.all(out == prediction * (-sigmas) + noisy_input)
    assert weighting is not None
def test_retrieve_timesteps():
    """retrieve_timesteps honors num_inference_steps and rejects conflicting overrides."""
    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

    # Requesting N inference steps yields exactly N timesteps.
    steps, step_count = retrieve_timesteps(noise_scheduler, num_inference_steps=50)
    assert len(steps) == 50
    assert step_count == 50

    # Supplying both custom timesteps and custom sigmas is ambiguous -> ValueError.
    with pytest.raises(ValueError):
        retrieve_timesteps(noise_scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
def test_get_noisy_model_input_and_timesteps():
    """Noisy inputs and timesteps carry the expected shapes/dtypes for each sampler."""

    class MockArgs:
        timestep_sampling = "uniform"
        weighting_scheme = "sigma_sqrt"
        sigmoid_scale = 1.0
        discrete_flow_shift = 6.0
        ip_noise_gamma = True
        ip_noise_gamma_random_strength = 0.01

    args = MockArgs()
    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    cpu = torch.device("cpu")

    # Mock latents and matching noise.
    latents = torch.randn(4, 16, 64, 64)
    noise = torch.randn_like(latents)

    def check_shapes(noisy, ts):
        # Shared invariants: noisy input mirrors latents; one timestep per sample.
        assert noisy.shape == latents.shape
        assert ts.shape[0] == latents.shape[0]

    # Uniform sampling: also verify the requested dtype is propagated.
    noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
        args, noise_scheduler, latents, noise, cpu, torch.float32
    )
    check_shapes(noisy_input, timesteps)
    assert noisy_input.dtype == torch.float32
    assert timesteps.dtype == torch.float32

    # The remaining sampling strategies share the same shape contract.
    for strategy in ("sigmoid", "shift", "nextdit_shift"):
        args.timestep_sampling = strategy
        noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, cpu, torch.float32
        )
        check_shapes(noisy_input, timesteps)