feat: add vae_chunk_size argument for memory-efficient VAE decoding and processing

This commit is contained in:
Kohya S
2025-09-21 11:09:37 +09:00
parent 8f20c37949
commit f41e9e2b58
4 changed files with 185 additions and 56 deletions

View File

@@ -88,7 +88,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders")
parser.add_argument("--vae_enable_tiling", action="store_true", help="Enable tiling for VAE decoding")
parser.add_argument(
"--vae_chunk_size",
type=int,
default=None, # default is None (no chunking)
help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled"
" / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNoneチャンクなし。有効にする場合は16程度を推奨。",
)
parser.add_argument(
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
)
@@ -431,14 +437,10 @@ def merge_lora_weights(
# endregion
def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device, enable_tiling: bool = False) -> torch.Tensor:
def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device) -> torch.Tensor:
logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}")
vae.to(device)
if enable_tiling:
vae.enable_tiling()
else:
vae.disable_tiling()
with torch.no_grad():
latent = latent / vae.scaling_factor # scale latent back to original range
pixels = vae.decode(latent.to(device, dtype=vae.dtype))
@@ -807,7 +809,7 @@ def save_output(
vae: HunyuanVAE2D,
latent: torch.Tensor,
device: torch.device,
original_base_names: Optional[List[str]] = None,
original_base_name: Optional[str] = None,
) -> None:
"""save output
@@ -816,7 +818,7 @@ def save_output(
vae: VAE model
latent: latent tensor
device: device to use
original_base_names: original base names (if latents are loaded from files)
original_base_name: original base name (if latents are loaded from files)
"""
height, width = latent.shape[-2], latent.shape[-1] # BCTHW
height *= hunyuan_image_vae.VAE_SCALE_FACTOR
@@ -839,14 +841,14 @@ def save_output(
1, vae.latent_channels, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR
)
image = decode_latent(vae, latent, device, args.vae_enable_tiling)
image = decode_latent(vae, latent, device)
if args.output_type == "images" or args.output_type == "latent_images":
# save images
if original_base_names is None or len(original_base_names) == 0:
if original_base_name is None:
original_name = ""
else:
original_name = f"_{original_base_names[0]}"
original_name = f"_{original_base_name}"
save_images(image, args, original_name)
@@ -919,7 +921,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
# 1. Prepare VAE
logger.info("Loading VAE for batch generation...")
vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
vae_for_batch.eval()
all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first
@@ -1057,7 +1059,7 @@ def process_interactive(args: argparse.Namespace) -> None:
shared_models = load_shared_models(args)
shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
vae.eval()
print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
@@ -1185,9 +1187,9 @@ def main():
for i, latent in enumerate(latents_list):
args.seed = seeds[i]
vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True)
vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True, chunk_size=args.vae_chunk_size)
vae.eval()
save_output(args, vae, latent, device, original_base_names)
save_output(args, vae, latent, device, original_base_names[i])
elif args.from_file:
# Batch mode from file
@@ -1220,7 +1222,7 @@ def main():
clean_memory_on_device(device)
# Save latent and video
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
vae.eval()
save_output(args, vae, latent, device)