Compare commits

...

5 Commits

Author SHA1 Message Date
Dave Lage
1fb482dab5 Merge 7a08c52aa4 into a5a162044c 2025-11-04 02:47:24 +00:00
rockerBOO
7a08c52aa4 Add error with CDC if cache_latents or cache_latents_to_disk is not set 2025-11-03 21:47:15 -05:00
rockerBOO
377299851a Fix cdc cache file validation 2025-11-02 23:22:10 -05:00
Kohya S.
a5a162044c Merge pull request #2226 from kohya-ss/fix-hunyuan-image-batch-gen-error
fix: error on batch generation closes #2209
2025-10-15 21:57:45 +09:00
Kohya S
a33cad714e fix: error on batch generation closes #2209 2025-10-15 21:57:11 +09:00
3 changed files with 341 additions and 3 deletions

View File

@@ -1001,7 +1001,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
all_precomputed_text_data.append(text_data)
# Models should be removed from device after prepare_text_inputs
del tokenizer_batch, text_encoder_batch, temp_shared_models_txt, conds_cache_batch
del tokenizer_vlm, text_encoder_vlm_batch, tokenizer_byt5, text_encoder_byt5_batch, temp_shared_models_txt, conds_cache_batch
gc.collect() # Force cleanup of Text Encoder from GPU memory
clean_memory_on_device(device)
@@ -1075,7 +1075,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
# save_output expects latent to be [BCTHW] or [CTHW]. generate returns [BCTHW] (batch size 1).
# latent[0] is correct if generate returns it with batch dim.
# The latent from generate is (1, C, T, H, W)
save_output(current_args, vae_for_batch, latent[0], device) # Pass vae_for_batch
save_output(current_args, vae_for_batch, latent, device) # Pass vae_for_batch
vae_for_batch.to("cpu") # Move VAE back to CPU

View File

@@ -2736,6 +2736,29 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
"""
from pathlib import Path
# Validate that latent caching is enabled
# CDC requires latents to be cached (either to disk or in memory) because:
# 1. CDC files are named based on latent cache filenames
# 2. CDC files are saved next to latent cache files
# 3. Training needs latent paths to load corresponding CDC files
has_cached_latents = False
for dataset in self.datasets:
for info in dataset.image_data.values():
if info.latents is not None or info.latents_npz is not None:
has_cached_latents = True
break
if has_cached_latents:
break
if not has_cached_latents:
raise ValueError(
"CDC-FM requires latent caching to be enabled. "
"Please enable latent caching by setting one of:\n"
" - cache_latents = true (cache in memory)\n"
" - cache_latents_to_disk = true (cache to disk)\n"
"in your training config or command line arguments."
)
# Collect dataset/subset directories for config hash
dataset_dirs = []
for dataset in self.datasets:
@@ -2851,9 +2874,39 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
# If latents_npz not set, we can't check for CDC cache
continue
cdc_path = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash)
# Compute expected latent shape from bucket_reso
# For multi-resolution CDC, we need to pass latent_shape to get the correct filename
latent_shape = None
if info.bucket_reso is not None:
# Get latent shape efficiently without loading full data
# First check if latent is already in memory
if info.latents is not None:
latent_shape = info.latents.shape
else:
# Load latent shape from npz file metadata
# This is faster than loading the full latent data
try:
import numpy as np
with np.load(info.latents_npz) as data:
# Find the key for this bucket resolution
# Multi-resolution format uses keys like "latents_104x80"
h, w = info.bucket_reso[1] // 8, info.bucket_reso[0] // 8
key = f"latents_{h}x{w}"
if key in data:
latent_shape = data[key].shape
elif 'latents' in data:
# Fallback for single-resolution cache
latent_shape = data['latents'].shape
except Exception as e:
logger.debug(f"Failed to read latent shape from {info.latents_npz}: {e}")
# Fall back to checking without shape (backward compatibility)
latent_shape = None
cdc_path = CDCPreprocessor.get_cdc_npz_path(info.latents_npz, config_hash, latent_shape)
if not Path(cdc_path).exists():
missing_count += 1
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Missing CDC cache: {cdc_path}")
if missing_count > 0:
logger.info(f"Found {missing_count}/{total_count} missing CDC cache files")

View File

@@ -0,0 +1,285 @@
"""
Test CDC cache detection with multi-resolution filenames
This test verifies that _check_cdc_caches_exist() correctly detects CDC cache files
that include resolution information in their filenames (e.g., image_flux_cdc_104x80_hash.npz).
This was a bug where the check was looking for files without resolution
(image_flux_cdc_hash.npz) while the actual files had resolution in the name.
"""
import os
import tempfile
import shutil
from pathlib import Path
import numpy as np
import pytest
from library.train_util import DatasetGroup, ImageInfo
from library.cdc_fm import CDCPreprocessor
class MockDataset:
    """Minimal stand-in for a training dataset.

    Exposes only the attributes that DatasetGroup's CDC cache checks read:
    ``image_data`` (image_key -> ImageInfo), ``image_dir``, and the two
    image counters.
    """

    def __init__(self, image_data):
        # Mapping of image_key -> ImageInfo, mirroring the real dataset layout.
        self.image_data = image_data
        self.image_dir = "/mock/dataset"
        # Every entry counts as a training image; no regularization images.
        self.num_train_images = len(self.image_data)
        self.num_reg_images = 0

    def __len__(self):
        return len(self.image_data)
def test_cdc_cache_detection_with_resolution():
    """Detect CDC cache files whose names embed the latent resolution.

    Regression test: CDC files are written with the resolution in the name
    (image_flux_cdc_104x80_hash.npz), but the existence check used to look
    for the resolution-less form (image_flux_cdc_hash.npz), so valid caches
    were treated as missing and regenerated.
    """
    with tempfile.TemporaryDirectory() as work_dir:
        config_hash = "test1234"
        shape = (16, 104, 80)  # C, H, W for bucket 832x640 (832/8=104, 640/8=80)

        # Latent cache on disk, stored under the multi-resolution key.
        lat_file = Path(work_dir) / "image_0832x0640_flux.npz"
        np.savez(
            lat_file,
            **{f"latents_{shape[1]}x{shape[2]}": np.random.randn(*shape).astype(np.float32)},
        )

        # CDC cache created exactly as the preprocessor would name it.
        cdc_file = CDCPreprocessor.get_cdc_npz_path(str(lat_file), config_hash, shape)
        assert "104x80" in cdc_file, f"CDC path should include resolution: {cdc_file}"
        np.savez(
            cdc_file,
            eigenvectors=np.random.randn(8, 16 * 104 * 80).astype(np.float16),
            eigenvalues=np.random.randn(8).astype(np.float16),
            shape=np.array(shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0,
        )

        # Dataset entry pointing at the disk cache; latent not held in memory.
        info = ImageInfo("test_image", 1, "test", False, str(Path(work_dir) / "image.png"))
        info.latents_npz = str(lat_file)
        info.bucket_reso = (640, 832)  # W, H — reversed relative to latent H, W
        info.latents = None

        group = DatasetGroup([MockDataset({"test_image": info})])
        assert group._check_cdc_caches_exist(config_hash) is True, (
            "CDC cache file should be detected when it exists with resolution in filename"
        )
def test_cdc_cache_detection_missing_file():
    """A latent cache without a matching CDC file must be reported missing."""
    with tempfile.TemporaryDirectory() as work_dir:
        config_hash = "test5678"
        shape = (16, 96, 64)  # C, H, W

        # Only the latent cache exists on disk; no CDC file is ever written.
        lat_file = Path(work_dir) / "image_0768x0512_flux.npz"
        np.savez(
            lat_file,
            **{f"latents_{shape[1]}x{shape[2]}": np.random.randn(*shape).astype(np.float32)},
        )

        info = ImageInfo("test_image", 1, "test", False, str(Path(work_dir) / "image.png"))
        info.latents_npz = str(lat_file)
        info.bucket_reso = (512, 768)  # W, H
        info.latents = None

        group = DatasetGroup([MockDataset({"test_image": info})])
        assert group._check_cdc_caches_exist(config_hash) is False, (
            "Should detect that CDC cache file is missing"
        )
def test_cdc_cache_detection_with_in_memory_latent():
    """When the latent tensor is already in memory, its shape (not the npz
    file on disk) should drive CDC cache filename resolution."""
    with tempfile.TemporaryDirectory() as work_dir:
        config_hash = "test_mem1"
        shape = (16, 128, 128)  # C, H, W

        # The latent npz itself is never written — only its path is needed
        # to derive where the CDC cache file lives.
        lat_file = Path(work_dir) / "image_1024x1024_flux.npz"
        cdc_file = CDCPreprocessor.get_cdc_npz_path(str(lat_file), config_hash, shape)
        np.savez(
            cdc_file,
            eigenvectors=np.random.randn(8, 16 * 128 * 128).astype(np.float16),
            eigenvalues=np.random.randn(8).astype(np.float16),
            shape=np.array(shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0,
        )

        import torch

        info = ImageInfo("test_image", 1, "test", False, str(Path(work_dir) / "image.png"))
        info.latents_npz = str(lat_file)
        info.bucket_reso = (1024, 1024)  # W, H
        info.latents = torch.randn(shape)  # latent resident in memory

        group = DatasetGroup([MockDataset({"test_image": info})])
        assert group._check_cdc_caches_exist(config_hash) is True, (
            "CDC cache should be detected using in-memory latent shape"
        )
def test_cdc_cache_detection_partial_cache():
    """If only some images have CDC caches, the group-level check fails."""
    with tempfile.TemporaryDirectory() as work_dir:
        config_hash = "testpart"
        shape = (16, 80, 64)

        # Two latent caches on disk sharing the same resolution.
        first = Path(work_dir) / "image1_0640x0512_flux.npz"
        second = Path(work_dir) / "image2_0640x0512_flux.npz"
        for lat_file in (first, second):
            np.savez(
                lat_file,
                **{f"latents_{shape[1]}x{shape[2]}": np.random.randn(*shape).astype(np.float32)},
            )

        # CDC cache exists for the first image only; the second is missing.
        cdc_first = CDCPreprocessor.get_cdc_npz_path(str(first), config_hash, shape)
        np.savez(
            cdc_first,
            eigenvectors=np.random.randn(8, 16 * 80 * 64).astype(np.float16),
            eigenvalues=np.random.randn(8).astype(np.float16),
            shape=np.array(shape),
            k_neighbors=256,
            d_cdc=8,
            gamma=1.0,
        )

        info1 = ImageInfo(
            image_key="img1", num_repeats=1, caption="test", is_reg=False,
            absolute_path=str(Path(work_dir) / "img1.png"),
        )
        info1.latents_npz = str(first)
        info1.bucket_reso = (512, 640)
        info1.latents = None

        info2 = ImageInfo(
            image_key="img2", num_repeats=1, caption="test", is_reg=False,
            absolute_path=str(Path(work_dir) / "img2.png"),
        )
        info2.latents_npz = str(second)
        info2.bucket_reso = (512, 640)
        info2.latents = None

        group = DatasetGroup([MockDataset({"img1": info1, "img2": info2})])
        assert group._check_cdc_caches_exist(config_hash) is False, (
            "Should detect that some CDC cache files are missing"
        )
def test_cdc_requires_latent_caching():
    """cache_cdc_gamma_b must raise a clear ValueError when neither
    in-memory nor on-disk latent caching is configured."""
    with tempfile.TemporaryDirectory() as work_dir:
        # Neither latents nor latents_npz is set — latent caching is off.
        info = ImageInfo("test_image", 1, "test", False, str(Path(work_dir) / "image.png"))
        info.latents_npz = None  # no disk cache
        info.latents = None  # no in-memory cache
        info.bucket_reso = (512, 512)

        group = DatasetGroup([MockDataset({"test_image": info})])

        with pytest.raises(ValueError) as exc_info:
            group.cache_cdc_gamma_b(k_neighbors=256, k_bandwidth=8, d_cdc=8, gamma=1.0)

        # The message must name the requirement and both caching switches.
        msg = str(exc_info.value)
        assert "CDC-FM requires latent caching" in msg
        assert "cache_latents" in msg
        assert "cache_latents_to_disk" in msg
if __name__ == "__main__":
    # Allow running this file directly (python this_file.py); -v prints
    # one line per test instead of the compact dot output.
    pytest.main([__file__, "-v"])