# Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08).
# Tests for library.flux_train_utils.get_noisy_model_input_and_timesteps.
import pytest
|
|
import torch
|
|
from unittest.mock import MagicMock, patch
|
|
from library.flux_train_utils import (
|
|
get_noisy_model_input_and_timesteps,
|
|
)
|
|
|
|
# Mock classes and functions
|
|
class MockNoiseScheduler:
    """Minimal stand-in for a diffusers-style noise scheduler.

    Exposes only what the code under test reads: ``config.num_train_timesteps``
    and an ascending integer ``timesteps`` tensor of that length.
    """

    def __init__(self, num_train_timesteps=1000):
        config = MagicMock()
        config.num_train_timesteps = num_train_timesteps
        self.config = config
        self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
|
|
|
|
|
|
# Create fixtures for commonly used objects
|
|
@pytest.fixture
def args():
    """Argument namespace with the trainer's default sampling/noise options."""
    mock_args = MagicMock()
    defaults = {
        "timestep_sampling": "uniform",
        "weighting_scheme": "uniform",
        "logit_mean": 0.0,
        "logit_std": 1.0,
        "mode_scale": 1.0,
        "sigmoid_scale": 1.0,
        "discrete_flow_shift": 3.1582,
        "ip_noise_gamma": None,
        "ip_noise_gamma_random_strength": False,
    }
    for name, value in defaults.items():
        setattr(mock_args, name, value)
    return mock_args
|
|
|
|
|
|
@pytest.fixture
def noise_scheduler():
    """Scheduler stub configured with the standard 1000 training timesteps."""
    scheduler = MockNoiseScheduler(num_train_timesteps=1000)
    return scheduler
|
|
|
|
|
|
@pytest.fixture
def latents():
    """Random latent batch of shape (batch=2, channels=4, height=8, width=8)."""
    shape = (2, 4, 8, 8)
    return torch.randn(*shape)
|
|
|
|
|
|
@pytest.fixture
def noise():
    """Random noise tensor matching the ``latents`` fixture's shape."""
    shape = (2, 4, 8, 8)
    return torch.randn(*shape)
|
|
|
|
|
|
@pytest.fixture
def device():
    """Target device; pinned to CPU so results are reproducible on any runner.

    NOTE(review): use ``"cuda" if torch.cuda.is_available() else "cpu"`` to
    exercise the GPU path locally.
    """
    return "cpu"
|
|
|
|
|
|
# Mock the required functions
|
|
@pytest.fixture(autouse=True)
def mock_functions():
    """Wrap torch's RNG/activation entry points in pass-through mocks.

    Each patch forwards to the real implementation via ``side_effect`` (bound
    before the patch is applied), so behavior is unchanged while the calls
    remain observable as mocks for the duration of every test.
    """
    with patch("torch.sigmoid", side_effect=torch.sigmoid):
        with patch("torch.rand", side_effect=torch.rand):
            with patch("torch.randn", side_effect=torch.randn):
                yield
|
|
|
|
|
|
# Test different timestep sampling methods
|
|
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
    """Uniform timestep sampling returns correctly shaped and typed tensors."""
    args.timestep_sampling = "uniform"
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    batch = latents.shape[0]
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (batch,)
    assert sigmas.shape == (batch, 1, 1, 1)
    assert noisy_input.dtype == dtype
    assert timesteps.dtype == dtype
|
|
|
|
|
|
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
    """Sigmoid timestep sampling returns correctly shaped tensors."""
    args.timestep_sampling = "sigmoid"
    args.sigmoid_scale = 1.0
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    batch = latents.shape[0]
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (batch,)
    assert sigmas.shape == (batch, 1, 1, 1)
|
|
|
|
|
|
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
    """Discrete-flow 'shift' sampling returns correctly shaped tensors."""
    args.timestep_sampling = "shift"
    args.sigmoid_scale = 1.0
    args.discrete_flow_shift = 3.1582
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    batch = latents.shape[0]
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (batch,)
    assert sigmas.shape == (batch, 1, 1, 1)
|
|
|
|
|
|
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
    """FLUX resolution-dependent 'flux_shift' sampling returns correct shapes."""
    args.timestep_sampling = "flux_shift"
    args.sigmoid_scale = 1.0
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    batch = latents.shape[0]
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (batch,)
    assert sigmas.shape == (batch, 1, 1, 1)
|
|
|
|
|
|
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
    """An unrecognized sampling mode falls through to the weighting-scheme path.

    The density/sigma helpers are stubbed out so only the surrounding plumbing
    (shapes, broadcasting) is exercised.
    """
    args.timestep_sampling = "other"  # anything unrecognized triggers the weighting-scheme branch
    args.weighting_scheme = "uniform"
    args.logit_mean = 0.0
    args.logit_std = 1.0
    args.mode_scale = 1.0
    dtype = torch.float32

    fake_density = torch.tensor([0.3, 0.7], device=device)
    fake_sigmas = torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)
    with (
        patch("library.flux_train_utils.compute_density_for_timestep_sampling", return_value=fake_density),
        patch("library.flux_train_utils.get_sigmas", return_value=fake_sigmas),
    ):
        noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, device, dtype
        )

        assert noisy_input.shape == latents.shape
        assert timesteps.shape == (latents.shape[0],)
        assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
|
|
|
|
|
# Test IP noise options
|
|
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
    """Fixed-strength input-perturbation noise preserves all output shapes."""
    args.ip_noise_gamma = 0.5
    args.ip_noise_gamma_random_strength = False
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    batch = latents.shape[0]
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (batch,)
    assert sigmas.shape == (batch, 1, 1, 1)
|
|
|
|
|
|
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
    """Randomized-strength input-perturbation noise preserves all output shapes."""
    args.ip_noise_gamma = 0.1
    args.ip_noise_gamma_random_strength = True
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    batch = latents.shape[0]
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (batch,)
    assert sigmas.shape == (batch, 1, 1, 1)
|
|
|
|
|
|
# Test different data types
|
|
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
    """The requested half-precision dtype propagates to both outputs."""
    dtype = torch.float16

    noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.dtype == dtype
    assert timesteps.dtype == dtype
|
|
|
|
|
|
# Test different batch sizes
|
|
def test_different_batch_size(args, noise_scheduler, device):
    """A non-default batch size (5) is reflected in every output shape."""
    batch = 5
    latents = torch.randn(batch, 4, 8, 8)
    noise = torch.randn(batch, 4, 8, 8)
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (batch,)
    assert sigmas.shape == (batch, 1, 1, 1)
|
|
|
|
|
|
# Test different image sizes
|
|
def test_different_image_size(args, noise_scheduler, device):
    """Larger spatial dimensions (16x16) are handled without shape errors."""
    latents = torch.randn(2, 4, 16, 16)
    noise = torch.randn(2, 4, 16, 16)
    dtype = torch.float32

    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (2,)
    assert sigmas.shape == (2, 1, 1, 1)
|
|
|
|
|
|
# Test edge cases
|
|
def test_zero_batch_size(args, noise_scheduler, device):
    """An empty batch is rejected with an AssertionError."""
    # Building the empty tensors cannot raise, so only the call under test
    # sits inside the raises block.
    latents = torch.randn(0, 4, 8, 8)
    noise = torch.randn(0, 4, 8, 8)
    dtype = torch.float32

    with pytest.raises(AssertionError):
        get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
|
|
|
|
|
def test_different_timestep_count(args, device):
    """A scheduler with fewer timesteps (500) bounds the sampled timesteps."""
    scheduler = MockNoiseScheduler(num_train_timesteps=500)
    latents = torch.randn(2, 4, 8, 8)
    noise = torch.randn(2, 4, 8, 8)
    dtype = torch.float32

    noisy_input, timesteps, _ = get_noisy_model_input_and_timesteps(args, scheduler, latents, noise, device, dtype)

    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (2,)
    # every sampled timestep must fall within the scheduler's range
    assert bool((timesteps < 500).all())
|