Add more tests

This commit is contained in:
rockerBOO
2025-06-16 17:22:32 -04:00
parent 9f95d4f347
commit 4c8ebf7293
2 changed files with 107 additions and 0 deletions

View File

@@ -223,3 +223,42 @@ def test_different_device_compatibility(loss, timesteps, noise_scheduler):
noise_scheduler.get_snr_for_timestep.return_value = snr_tensor
result = apply_snr_weight(loss_on_device, timesteps, noise_scheduler, gamma)
# Additional tests for new functionality
def test_apply_snr_weight_with_image_size(loss, timesteps, noise_scheduler):
"""Test SNR weight application with image size consideration"""
gamma = 5.0
image_sizes = [None, 64, (256, 256)]
for image_size in image_sizes:
result = apply_snr_weight(
loss,
timesteps,
noise_scheduler,
gamma,
v_prediction=False,
image_size=image_size
)
# Allow for broadcasting
assert result.shape[0] == loss.shape[0]
assert result.dtype == loss.dtype
def test_apply_debiased_estimation_variations(loss, timesteps, noise_scheduler):
"""Test debiased estimation with different image sizes and prediction types"""
image_sizes = [None, 64, (256, 256)]
prediction_types = [True, False]
for image_size in image_sizes:
for v_prediction in prediction_types:
result = apply_debiased_estimation(
loss,
timesteps,
noise_scheduler,
v_prediction=v_prediction,
image_size=image_size
)
# Allow for broadcasting
assert result.shape[0] == loss.shape[0]
assert result.dtype == loss.dtype

View File

@@ -1,6 +1,8 @@
import pytest
import torch
import math
from unittest.mock import MagicMock, patch
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
)
@@ -218,3 +220,69 @@ def test_different_timestep_count(args, device):
assert timesteps.shape == (2,)
# Check that timesteps are within the proper range
assert torch.all(timesteps < 500)
# New tests for dynamic timestep shifting
def test_dynamic_timestep_shifting(device):
"""Test the dynamic timestep shifting functionality"""
# Create a scheduler with dynamic shifting enabled
scheduler = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000,
shift=1.0,
use_dynamic_shifting=True
)
# Test different image sizes
test_sizes = [
(64, 64), # Small image
(256, 256), # Medium image
(512, 512), # Large image
(1024, 1024) # Very large image
]
for image_size in test_sizes:
# Simulate setting timesteps for inference
mu = math.log(1 + (image_size[0] * image_size[1]) / (256 * 256))
scheduler.set_timesteps(num_inference_steps=50, mu=mu)
# Check that sigmas have been dynamically shifted
assert len(scheduler.sigmas) == 51 # num_inference_steps + 1
assert scheduler.sigmas[0] <= 1.0 # Maximum sigma should be <= 1
assert scheduler.sigmas[-1] == 0.0 # Last sigma should always be 0
def test_sigma_generation_methods():
"""Test different sigma generation methods"""
# Test Karras sigmas
scheduler_karras = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000,
use_karras_sigmas=True
)
scheduler_karras.set_timesteps(num_inference_steps=50)
assert len(scheduler_karras.sigmas) == 51
# Test Exponential sigmas
scheduler_exp = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000,
use_exponential_sigmas=True
)
scheduler_exp.set_timesteps(num_inference_steps=50)
assert len(scheduler_exp.sigmas) == 51
def test_snr_calculation():
"""Test the SNR calculation method"""
scheduler = FlowMatchEulerDiscreteScheduler(
num_train_timesteps=1000,
shift=1.0
)
# Prepare test timesteps
timesteps = torch.tensor([200, 600], dtype=torch.int32)
# Test with different image sizes
test_sizes = [None, 64, (256, 256)]
for image_size in test_sizes:
snr_values = scheduler.get_snr_for_timestep(timesteps, image_size)
# Check basic properties
assert snr_values.shape == torch.Size([2])
assert torch.all(snr_values >= 0) # SNR should always be non-negative