Fix timestep/timestep refactor

This commit is contained in:
rockerBOO
2025-04-28 16:11:12 -04:00
parent 61e3083945
commit 10ce29f4fe

View File

@@ -2,7 +2,7 @@ import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
get_noisy_model_input_and_timestep,
)
# Mock classes and functions
@@ -66,13 +66,13 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "uniform"
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
assert timestep.dtype == dtype
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
@@ -80,11 +80,11 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
@@ -93,11 +93,11 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device):
args.discrete_flow_shift = 3.1582
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
@@ -105,11 +105,11 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
@@ -126,13 +126,13 @@ def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
args.mode_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(
args, noise_scheduler, latents, noise, device, dtype
)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
# Test IP noise options
@@ -141,11 +141,11 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma_random_strength = False
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
@@ -153,11 +153,11 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma_random_strength = True
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert timestep.shape == (latents.shape[0],)
assert sigma.shape == (latents.shape[0], 1, 1, 1)
# Test different data types
@@ -176,11 +176,11 @@ def test_different_batch_size(args, noise_scheduler, device):
noise = torch.randn(5, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (5,)
assert sigmas.shape == (5, 1, 1, 1)
assert timestep.shape == (5,)
assert sigma.shape == (5, 1, 1, 1)
# Test different image sizes
@@ -189,11 +189,11 @@ def test_different_image_size(args, noise_scheduler, device):
noise = torch.randn(2, 4, 16, 16)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
assert sigmas.shape == (2, 1, 1, 1)
assert timestep.shape == (2,)
assert sigma.shape == (2, 1, 1, 1)
# Test edge cases
@@ -203,7 +203,7 @@ def test_zero_batch_size(args, noise_scheduler, device):
noise = torch.randn(0, 4, 8, 8)
dtype = torch.float32
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
def test_different_timestep_count(args, device):
@@ -212,9 +212,9 @@ def test_different_timestep_count(args, device):
noise = torch.randn(2, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
assert timestep.shape == (2,)
# Check that timesteps are within the proper range
assert torch.all(timesteps < 500)
assert torch.all(timestep < 500)