Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-17 17:24:21 +00:00)
Fix multi-resolution support in cached files
This commit is contained in:
@@ -52,9 +52,10 @@ class TestCDCPreprocessorIntegration:
|
||||
# Verify files were created
|
||||
assert files_saved == 10
|
||||
|
||||
# Verify first CDC file structure (with config hash)
|
||||
# Verify first CDC file structure (with config hash and latent shape)
|
||||
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
|
||||
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
|
||||
latent_shape = (16, 4, 4)
|
||||
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape))
|
||||
assert cdc_path.exists()
|
||||
|
||||
import numpy as np
|
||||
@@ -112,12 +113,12 @@ class TestCDCPreprocessorIntegration:
|
||||
assert files_saved == 10
|
||||
|
||||
import numpy as np
|
||||
# Check shapes are stored in individual files (with config hash)
|
||||
# Check shapes are stored in individual files (with config hash and latent shape)
|
||||
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
|
||||
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
|
||||
)
|
||||
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
|
||||
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
|
||||
)
|
||||
data_0 = np.load(cdc_path_0)
|
||||
data_5 = np.load(cdc_path_5)
|
||||
@@ -278,8 +279,9 @@ class TestCDCEndToEnd:
|
||||
batch_t = torch.rand(batch_size)
|
||||
latents_npz_paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu")
|
||||
# Get Γ_b components (pass latent_shape for multi-resolution support)
|
||||
latent_shape = (16, 4, 4)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(latents_npz_paths_batch, device="cpu", latent_shape=latent_shape)
|
||||
|
||||
# Compute geometry-aware noise
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||
|
||||
@@ -45,7 +45,8 @@ class TestCDCPreprocessor:
|
||||
|
||||
# Verify first CDC file structure
|
||||
latents_npz_path = str(tmp_path / "test_image_0_0004x0004_flux.npz")
|
||||
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash))
|
||||
latent_shape = (16, 4, 4)
|
||||
cdc_path = Path(CDCPreprocessor.get_cdc_npz_path(latents_npz_path, preprocessor.config_hash, latent_shape))
|
||||
assert cdc_path.exists()
|
||||
|
||||
data = np.load(cdc_path)
|
||||
@@ -100,10 +101,10 @@ class TestCDCPreprocessor:
|
||||
|
||||
# Check shapes are stored in individual files
|
||||
cdc_path_0 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash
|
||||
str(tmp_path / "test_image_0_0004x0004_flux.npz"), preprocessor.config_hash, latent_shape=(16, 4, 4)
|
||||
)
|
||||
cdc_path_5 = CDCPreprocessor.get_cdc_npz_path(
|
||||
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash
|
||||
str(tmp_path / "test_image_5_0008x0008_flux.npz"), preprocessor.config_hash, latent_shape=(16, 8, 8)
|
||||
)
|
||||
|
||||
data_0 = np.load(cdc_path_0)
|
||||
@@ -152,7 +153,8 @@ class TestGammaBDataset:
|
||||
gamma_b_dataset = GammaBDataset(device="cpu", config_hash=config_hash)
|
||||
|
||||
# Get components for first sample
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu")
|
||||
latent_shape = (16, 8, 8)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt([latents_npz_paths[0]], device="cpu", latent_shape=latent_shape)
|
||||
|
||||
# Check shapes
|
||||
assert eigvecs.shape[0] == 1 # batch size
|
||||
@@ -166,7 +168,8 @@ class TestGammaBDataset:
|
||||
|
||||
# Get Γ_b for paths [0, 2, 4]
|
||||
paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]]
|
||||
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
|
||||
latent_shape = (16, 8, 8)
|
||||
eigenvectors, eigenvalues = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape)
|
||||
|
||||
# Check shapes
|
||||
assert eigenvectors.shape[0] == 3 # batch
|
||||
@@ -187,7 +190,8 @@ class TestGammaBDataset:
|
||||
|
||||
# Get Γ_b components
|
||||
paths = [latents_npz_paths[0], latents_npz_paths[1], latents_npz_paths[2]]
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
|
||||
latent_shape = (16, 8, 8)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape)
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
@@ -204,7 +208,8 @@ class TestGammaBDataset:
|
||||
|
||||
# Get Γ_b components
|
||||
paths = [latents_npz_paths[1], latents_npz_paths[3]]
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
|
||||
latent_shape = (16, 8, 8)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape)
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
@@ -221,7 +226,8 @@ class TestGammaBDataset:
|
||||
|
||||
# Get Γ_b components
|
||||
paths = [latents_npz_paths[0], latents_npz_paths[2], latents_npz_paths[4]]
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu")
|
||||
latent_shape = (16, 8, 8)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths, device="cpu", latent_shape=latent_shape)
|
||||
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, x, t)
|
||||
|
||||
@@ -269,7 +275,8 @@ class TestCDCEndToEnd:
|
||||
paths_batch = [latents_npz_paths[0], latents_npz_paths[5], latents_npz_paths[9]]
|
||||
|
||||
# Get Γ_b components
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu")
|
||||
latent_shape = (16, 4, 4)
|
||||
eigvecs, eigvals = gamma_b_dataset.get_gamma_b_sqrt(paths_batch, device="cpu", latent_shape=latent_shape)
|
||||
|
||||
# Compute geometry-aware noise
|
||||
sigma_t_x = gamma_b_dataset.compute_sigma_t_x(eigvecs, eigvals, batch_latents_flat, batch_t)
|
||||
|
||||
Reference in New Issue
Block a user