mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
Fix timestep/timestep refactor
This commit is contained in:
@@ -2,7 +2,7 @@ import pytest
|
||||
import torch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from library.flux_train_utils import (
|
||||
get_noisy_model_input_and_timesteps,
|
||||
get_noisy_model_input_and_timestep,
|
||||
)
|
||||
|
||||
# Mock classes and functions
|
||||
@@ -66,13 +66,13 @@ def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "uniform"
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert noisy_input.dtype == dtype
|
||||
assert timesteps.dtype == dtype
|
||||
assert timestep.dtype == dtype
|
||||
|
||||
|
||||
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
|
||||
@@ -80,11 +80,11 @@ def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
@@ -93,11 +93,11 @@ def test_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.discrete_flow_shift = 3.1582
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
@@ -105,11 +105,11 @@ def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
|
||||
@@ -126,13 +126,13 @@ def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
|
||||
args.mode_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
# Test IP noise options
|
||||
@@ -141,11 +141,11 @@ def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma_random_strength = False
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
@@ -153,11 +153,11 @@ def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma_random_strength = True
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert timestep.shape == (latents.shape[0],)
|
||||
assert sigma.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
# Test different data types
|
||||
@@ -176,11 +176,11 @@ def test_different_batch_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(5, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (5,)
|
||||
assert sigmas.shape == (5, 1, 1, 1)
|
||||
assert timestep.shape == (5,)
|
||||
assert sigma.shape == (5, 1, 1, 1)
|
||||
|
||||
|
||||
# Test different image sizes
|
||||
@@ -189,11 +189,11 @@ def test_different_image_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(2, 4, 16, 16)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
assert sigmas.shape == (2, 1, 1, 1)
|
||||
assert timestep.shape == (2,)
|
||||
assert sigma.shape == (2, 1, 1, 1)
|
||||
|
||||
|
||||
# Test edge cases
|
||||
@@ -203,7 +203,7 @@ def test_zero_batch_size(args, noise_scheduler, device):
|
||||
noise = torch.randn(0, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
|
||||
def test_different_timestep_count(args, device):
|
||||
@@ -212,9 +212,9 @@ def test_different_timestep_count(args, device):
|
||||
noise = torch.randn(2, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
noisy_input, timestep, sigma = get_noisy_model_input_and_timestep(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
assert timestep.shape == (2,)
|
||||
# Check that timesteps are within the proper range
|
||||
assert torch.all(timesteps < 500)
|
||||
assert torch.all(timestep < 500)
|
||||
|
||||
Reference in New Issue
Block a user