mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
feat: add vae_chunk_size argument for memory-efficient VAE decoding and processing
This commit is contained in:
@@ -88,7 +88,13 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
|
||||
|
||||
parser.add_argument("--text_encoder_cpu", action="store_true", help="Inference on CPU for Text Encoders")
|
||||
parser.add_argument("--vae_enable_tiling", action="store_true", help="Enable tiling for VAE decoding")
|
||||
parser.add_argument(
|
||||
"--vae_chunk_size",
|
||||
type=int,
|
||||
default=None, # default is None (no chunking)
|
||||
help="Chunk size for VAE decoding to reduce memory usage. Default is None (no chunking). 16 is recommended if enabled"
|
||||
" / メモリ使用量を減らすためのVAEデコードのチャンクサイズ。デフォルトはNone(チャンクなし)。有効にする場合は16程度を推奨。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
|
||||
)
|
||||
@@ -431,14 +437,10 @@ def merge_lora_weights(
|
||||
# endregion
|
||||
|
||||
|
||||
def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device, enable_tiling: bool = False) -> torch.Tensor:
|
||||
def decode_latent(vae: HunyuanVAE2D, latent: torch.Tensor, device: torch.device) -> torch.Tensor:
|
||||
logger.info(f"Decoding image. Latent shape {latent.shape}, device {device}")
|
||||
|
||||
vae.to(device)
|
||||
if enable_tiling:
|
||||
vae.enable_tiling()
|
||||
else:
|
||||
vae.disable_tiling()
|
||||
with torch.no_grad():
|
||||
latent = latent / vae.scaling_factor # scale latent back to original range
|
||||
pixels = vae.decode(latent.to(device, dtype=vae.dtype))
|
||||
@@ -807,7 +809,7 @@ def save_output(
|
||||
vae: HunyuanVAE2D,
|
||||
latent: torch.Tensor,
|
||||
device: torch.device,
|
||||
original_base_names: Optional[List[str]] = None,
|
||||
original_base_name: Optional[str] = None,
|
||||
) -> None:
|
||||
"""save output
|
||||
|
||||
@@ -816,7 +818,7 @@ def save_output(
|
||||
vae: VAE model
|
||||
latent: latent tensor
|
||||
device: device to use
|
||||
original_base_names: original base names (if latents are loaded from files)
|
||||
original_base_name: original base name (if latents are loaded from files)
|
||||
"""
|
||||
height, width = latent.shape[-2], latent.shape[-1] # BCTHW
|
||||
height *= hunyuan_image_vae.VAE_SCALE_FACTOR
|
||||
@@ -839,14 +841,14 @@ def save_output(
|
||||
1, vae.latent_channels, height // hunyuan_image_vae.VAE_SCALE_FACTOR, width // hunyuan_image_vae.VAE_SCALE_FACTOR
|
||||
)
|
||||
|
||||
image = decode_latent(vae, latent, device, args.vae_enable_tiling)
|
||||
image = decode_latent(vae, latent, device)
|
||||
|
||||
if args.output_type == "images" or args.output_type == "latent_images":
|
||||
# save images
|
||||
if original_base_names is None or len(original_base_names) == 0:
|
||||
if original_base_name is None:
|
||||
original_name = ""
|
||||
else:
|
||||
original_name = f"_{original_base_names[0]}"
|
||||
original_name = f"_{original_base_name}"
|
||||
save_images(image, args, original_name)
|
||||
|
||||
|
||||
@@ -919,7 +921,7 @@ def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) ->
|
||||
|
||||
# 1. Prepare VAE
|
||||
logger.info("Loading VAE for batch generation...")
|
||||
vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae_for_batch = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
|
||||
vae_for_batch.eval()
|
||||
|
||||
all_prompt_args_list = [apply_overrides(args, pd) for pd in prompts_data] # Create all arg instances first
|
||||
@@ -1057,7 +1059,7 @@ def process_interactive(args: argparse.Namespace) -> None:
|
||||
shared_models = load_shared_models(args)
|
||||
shared_models["conds_cache"] = {} # Initialize empty cache for interactive mode
|
||||
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
|
||||
vae.eval()
|
||||
|
||||
print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")
|
||||
@@ -1185,9 +1187,9 @@ def main():
|
||||
for i, latent in enumerate(latents_list):
|
||||
args.seed = seeds[i]
|
||||
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True)
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device=device, disable_mmap=True, chunk_size=args.vae_chunk_size)
|
||||
vae.eval()
|
||||
save_output(args, vae, latent, device, original_base_names)
|
||||
save_output(args, vae, latent, device, original_base_names[i])
|
||||
|
||||
elif args.from_file:
|
||||
# Batch mode from file
|
||||
@@ -1220,7 +1222,7 @@ def main():
|
||||
clean_memory_on_device(device)
|
||||
|
||||
# Save latent and video
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True)
|
||||
vae = hunyuan_image_vae.load_vae(args.vae, device="cpu", disable_mmap=True, chunk_size=args.vae_chunk_size)
|
||||
vae.eval()
|
||||
save_output(args, vae, latent, device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user