Fix tests

This commit is contained in:
rockerBOO
2025-11-17 11:34:09 -05:00
parent cc0e4acf1b
commit 4888327caa

View File

@@ -2,7 +2,7 @@ import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
get_noisy_model_input_and_timestep,
)
# Mock classes and functions
@@ -66,7 +66,7 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "uniform"
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -80,7 +80,7 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -93,7 +93,7 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device):
args.discrete_flow_shift = 3.1582
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -105,7 +105,7 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -126,7 +126,7 @@ def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
args.mode_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, device, dtype
)
@@ -141,7 +141,7 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma_random_strength = False
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -153,7 +153,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma_random_strength = True
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
@@ -164,7 +164,7 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
dtype = torch.float16
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
@@ -176,7 +176,7 @@ def test_different_batch_size(args, noise_scheduler, device):
noise = torch.randn(5, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (5,)
@@ -189,7 +189,7 @@ def test_different_image_size(args, noise_scheduler, device):
noise = torch.randn(2, 4, 16, 16)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
@@ -203,7 +203,7 @@ def test_zero_batch_size(args, noise_scheduler, device):
noise = torch.randn(0, 4, 8, 8)
dtype = torch.float32
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
def test_different_timestep_count(args, device):
@@ -212,7 +212,7 @@ def test_different_timestep_count(args, device):
noise = torch.randn(2, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)