mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
756 lines
28 KiB
Python
756 lines
28 KiB
Python
from typing import Optional, Tuple
|
|
|
|
from einops import rearrange
|
|
import numpy as np
|
|
import torch
|
|
from torch import Tensor, nn
|
|
from torch.nn import Conv2d
|
|
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
|
|
|
from library.safetensors_utils import load_safetensors
|
|
from library.utils import setup_logging
|
|
|
|
setup_logging()
|
|
import logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
VAE_SCALE_FACTOR = 32 # 32x spatial compression
|
|
|
|
LATENT_SCALING_FACTOR = 0.75289 # Latent scaling factor for Hunyuan Image-2.1
|
|
|
|
|
|
def swish(x: Tensor) -> Tensor:
    """Apply the swish (SiLU) nonlinearity, ``x * sigmoid(x)``, elementwise."""
    return torch.sigmoid(x) * x
|
|
|
|
|
|
class AttnBlock(nn.Module):
    """Single-head spatial self-attention block (pre-norm, residual output)."""

    def __init__(self, in_channels: int, chunk_size: Optional[int] = None):
        super().__init__()
        self.in_channels = in_channels
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        use_chunked = chunk_size is not None and chunk_size > 0

        def make_proj():
            # 1x1 projection; the chunked variant bounds peak memory.
            if use_chunked:
                return ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
            return Conv2d(in_channels, in_channels, kernel_size=1)

        self.q = make_proj()
        self.k = make_proj()
        self.v = make_proj()
        self.proj_out = make_proj()

    def attention(self, x: Tensor) -> Tensor:
        """Run single-head attention over all spatial positions of ``x``."""
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        b, c, h, w = q.shape
        # Flatten spatial dims: (b, c, h, w) -> (b, h*w, c).
        q = q.flatten(2).transpose(1, 2).contiguous()
        k = k.flatten(2).transpose(1, 2).contiguous()
        v = v.flatten(2).transpose(1, 2).contiguous()

        out = nn.functional.scaled_dot_product_attention(q, k, v)
        # Restore spatial layout: (b, h*w, c) -> (b, c, h, w).
        return out.transpose(1, 2).reshape(b, c, h, w)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))
|
|
|
|
|
|
class ChunkedConv2d(nn.Conv2d):
    """
    Convolutional layer that processes input in chunks to reduce memory usage.

    The input is split along the height dimension into overlapping chunks that
    are convolved independently and written back into one output tensor, so the
    result matches a plain Conv2d. Only the configuration used by this VAE is
    supported: square kernels, stride=1, dilation=1, groups=1, 'zeros' padding
    of exactly kernel_size // 2.

    Parameters
    ----------
    chunk_size : int, optional
        Size of chunks to process at a time. Default is 64.
    """

    def __init__(self, *args, **kwargs):
        # Pop unconditionally so the attribute always exists: the previous
        # `if "chunk_size" in kwargs` guard left self.chunk_size undefined
        # when the kwarg was omitted, crashing forward() with AttributeError.
        self.chunk_size = kwargs.pop("chunk_size", 64)
        super().__init__(*args, **kwargs)
        assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
        assert self.dilation == (1, 1) and self.stride == (1, 1), "Only dilation=1 and stride=1 are supported."
        assert self.groups == 1, "Only groups=1 is supported."
        assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
        assert (
            self.padding[0] == self.padding[1] and self.padding[0] == self.kernel_size[0] // 2
        ), "Only kernel_size//2 padding is supported."
        self.original_padding = self.padding
        self.padding = (0, 0)  # We handle padding manually in forward

    def forward(self, x: Tensor) -> Tensor:
        # If chunking is not needed, process normally. We chunk only along height dimension.
        # NOTE(review): this gate compares the *channel* dim (shape[1]) with
        # chunk_size even though chunking happens along height (shape[2]);
        # both paths produce identical values, so this only decides when
        # chunking kicks in — confirm the intended heuristic upstream.
        if self.chunk_size is None or x.shape[1] <= self.chunk_size:
            self.padding = self.original_padding
            x = super().forward(x)
            self.padding = (0, 0)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return x

        # Process input in chunks to reduce memory usage
        org_shape = x.shape

        # If kernel size is not 1, we need to use overlapping chunks
        overlap = self.kernel_size[0] // 2  # 1 for kernel size 3
        step = self.chunk_size - overlap
        y = torch.zeros((org_shape[0], self.out_channels, org_shape[2], org_shape[3]), dtype=x.dtype, device=x.device)
        yi = 0  # next output row to fill
        i = 0  # top row (in the original input) of the current chunk
        while i < org_shape[2]:
            si = i if i == 0 else i - overlap  # chunk start, including top overlap
            ei = i + self.chunk_size

            # Check last chunk. If remaining part is small, include it in last chunk
            if ei > org_shape[2] or ei + step // 4 > org_shape[2]:
                ei = org_shape[2]

            # Slice the chunk off the front of x, then drop the consumed rows
            # (keeping 2*overlap rows so the next chunk has its top context).
            chunk = x[:, :, : ei - si, :]
            x = x[:, :, ei - si - overlap * 2 :, :]

            # Pad chunk if needed: this reproduces the original Conv2d padding.
            if i == 0 and ei == org_shape[2]:  # Single chunk covers everything
                # Fix: the first chunk can also be the last (input height fits
                # in one chunk); it then needs padding on all four sides.
                # Previously this fell into the first-chunk branch, produced a
                # short output, and tripped the yi assertion below.
                chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, overlap), mode="constant", value=0)
            elif i == 0:  # First chunk
                # Pad except bottom
                chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0)
            elif ei == org_shape[2]:  # Last chunk
                # Pad except top
                chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0)
            else:
                # Pad left and right only
                chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0)

            chunk = super().forward(chunk)
            y[:, :, yi : yi + chunk.shape[2], :] = chunk
            yi += chunk.shape[2]
            del chunk

            if ei == org_shape[2]:
                break
            i += step

        assert yi == org_shape[2], f"yi={yi}, org_shape[2]={org_shape[2]}"

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # This helps reduce peak memory usage, but slows down a bit
        return y
|
|
|
|
|
|
class ResnetBlock(nn.Module):
    """
    Two-convolution residual block (GroupNorm + swish before each conv),
    with a 1x1 projection on the skip path when channel counts differ.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)

        use_chunked = chunk_size is not None and chunk_size > 0
        if use_chunked:
            self.conv1 = ChunkedConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
            self.conv2 = ChunkedConv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
            # Skip connection projection for channel dimension mismatch
            if in_channels != out_channels:
                self.nin_shortcut = ChunkedConv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0, chunk_size=chunk_size
                )
        else:
            self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
            # Skip connection projection for channel dimension mismatch
            if in_channels != out_channels:
                self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x: Tensor) -> Tensor:
        # norm -> swish -> conv, twice
        h = self.conv1(swish(self.norm1(x)))
        h = self.conv2(swish(self.norm2(h)))

        # Project the skip path when channel counts differ, then add.
        skip = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
        return skip + h
|
|
|
|
|
|
class Downsample(nn.Module):
    """
    Spatial downsampling block that halves resolution using a convolution
    followed by space-to-depth, plus a grouped-average space-to-depth shortcut.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels (must be divisible by 4).
    """

    def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
        super().__init__()
        factor = 4  # 2x2 spatial reduction factor
        assert out_channels % factor == 0

        if chunk_size is not None and chunk_size > 0:
            self.conv = ChunkedConv2d(
                in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
            )
        else:
            self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)

        # How many shortcut channels are averaged into each output channel.
        self.group_size = factor * in_channels // out_channels

    @staticmethod
    def _space_to_depth(t: Tensor) -> Tensor:
        """(b, c, 2h, 2w) -> (b, 4c, h, w) with (r1, r2, c) channel ordering."""
        b, c, hh, ww = t.shape
        t = t.view(b, c, hh // 2, 2, ww // 2, 2)  # (b, c, h, r1, w, r2)
        t = t.permute(0, 3, 5, 1, 2, 4)  # (b, r1, r2, c, h, w)
        return t.reshape(b, 4 * c, hh // 2, ww // 2)

    def forward(self, x: Tensor) -> Tensor:
        # Convolve, then fold each 2x2 pixel block into channels.
        h = self._space_to_depth(self.conv(x))

        # Shortcut: same space-to-depth on the input, then average channel groups.
        shortcut = self._space_to_depth(x)
        b, _, hh, ww = shortcut.shape
        shortcut = shortcut.view(b, h.shape[1], self.group_size, hh, ww).mean(dim=2)

        return h + shortcut
|
|
|
|
|
|
class Upsample(nn.Module):
    """
    Spatial upsampling block that doubles resolution using a convolution
    followed by depth-to-space, plus a channel-repeated depth-to-space shortcut.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    out_channels : int
        Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
        super().__init__()
        factor = 4  # 2x2 spatial expansion factor

        if chunk_size is not None and chunk_size > 0:
            self.conv = ChunkedConv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
        else:
            self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)

        # How many times each input channel is repeated on the shortcut path.
        self.repeats = factor * out_channels // in_channels

    @staticmethod
    def _depth_to_space(t: Tensor) -> Tensor:
        """(b, 4c, h, w) with (r1, r2, c) channel ordering -> (b, c, 2h, 2w)."""
        b, c4, hh, ww = t.shape
        c = c4 // 4
        t = t.view(b, 2, 2, c, hh, ww)  # (b, r1, r2, c, h, w)
        t = t.permute(0, 3, 4, 1, 5, 2)  # (b, c, h, r1, w, r2)
        return t.reshape(b, c, hh * 2, ww * 2)

    def forward(self, x: Tensor) -> Tensor:
        # Convolve, then unfold channels into 2x2 pixel blocks.
        h = self._depth_to_space(self.conv(x))

        # Shortcut: repeat channels to match, then unfold the same way.
        shortcut = self._depth_to_space(x.repeat_interleave(repeats=self.repeats, dim=1))

        return h + shortcut
|
|
|
|
|
|
class Encoder(nn.Module):
    """
    VAE encoder that progressively downsamples input images to a latent representation.
    Uses residual blocks, attention, and spatial downsampling.

    Parameters
    ----------
    in_channels : int
        Number of input image channels (e.g., 3 for RGB).
    z_channels : int
        Number of latent channels in the output.
    block_out_channels : Tuple[int, ...]
        Output channels for each downsampling block.
    num_res_blocks : int
        Number of residual blocks per downsampling stage.
    ffactor_spatial : int
        Total spatial downsampling factor (e.g., 32 for 32x compression).
    chunk_size : int, optional
        If a positive value is given, convolutions use ChunkedConv2d to lower
        peak memory; None or non-positive uses plain Conv2d.
    """

    def __init__(
        self,
        in_channels: int,
        z_channels: int,
        block_out_channels: Tuple[int, ...],
        num_res_blocks: int,
        ffactor_spatial: int,
        chunk_size: Optional[int] = None,
    ):
        super().__init__()
        # conv_out emits 2 * z_channels (mean and logvar); the final skip
        # connection averages groups of the last feature channels into those
        # outputs, so the widths must divide evenly.
        assert block_out_channels[-1] % (2 * z_channels) == 0

        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        if chunk_size is None or chunk_size <= 0:
            self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
        else:
            self.conv_in = ChunkedConv2d(
                in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
            )

        self.down = nn.ModuleList()
        block_in = block_out_channels[0]

        # Build downsampling blocks
        for i_level, ch in enumerate(block_out_channels):
            block = nn.ModuleList()
            block_out = ch

            # Add residual blocks for this level
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
                block_in = block_out

            down = nn.Module()
            down.block = block

            # Add spatial downsampling if needed: only the first
            # log2(ffactor_spatial) levels downsample (e.g. 5 levels for a
            # 32x factor); deeper levels keep the resolution.
            add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
            if add_spatial_downsample:
                assert i_level < len(block_out_channels) - 1
                block_out = block_out_channels[i_level + 1]
                down.downsample = Downsample(block_in, block_out, chunk_size=chunk_size)
                block_in = block_out

            self.down.append(down)

        # Middle blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
        self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)

        # Output layers
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        if chunk_size is None or chunk_size <= 0:
            self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
        else:
            self.conv_out = ChunkedConv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)

    def forward(self, x: Tensor) -> Tensor:
        """Encode images (B, C, H, W) into concatenated mean/logvar features
        with 2 * z_channels channels and spatially downsampled resolution."""
        # Initial convolution
        h = self.conv_in(x)

        # Progressive downsampling through blocks
        for i_level in range(len(self.block_out_channels)):
            # Apply residual blocks at this level
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
            # Apply spatial downsampling if available
            if hasattr(self.down[i_level], "downsample"):
                h = self.down[i_level].downsample(h)

        # Middle processing with attention
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Final output layers with skip connection: average each group of
        # `group_size` consecutive feature channels down to one output channel.
        group_size = self.block_out_channels[-1] // (2 * self.z_channels)
        shortcut = rearrange(h, "b (c r) h w -> b c r h w", r=group_size).mean(dim=2)
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        h += shortcut
        return h
|
|
|
|
|
|
class Decoder(nn.Module):
    """
    VAE decoder that progressively upsamples latent representations back to images.
    Uses residual blocks, attention, and spatial upsampling.

    Parameters
    ----------
    z_channels : int
        Number of latent channels in the input.
    out_channels : int
        Number of output image channels (e.g., 3 for RGB).
    block_out_channels : Tuple[int, ...]
        Output channels for each upsampling block.
    num_res_blocks : int
        Number of residual blocks per upsampling stage.
    ffactor_spatial : int
        Total spatial upsampling factor (e.g., 32 for 32x expansion).
    chunk_size : int, optional
        If a positive value is given, convolutions use ChunkedConv2d to lower
        peak memory; None or non-positive uses plain Conv2d.
    """

    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: Tuple[int, ...],
        num_res_blocks: int,
        ffactor_spatial: int,
        chunk_size: Optional[int] = None,
    ):
        super().__init__()
        # The latent is channel-repeated onto the conv_in output in forward,
        # so the first block width must be a multiple of z_channels.
        assert block_out_channels[0] % z_channels == 0

        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks

        block_in = block_out_channels[0]
        if chunk_size is None or chunk_size <= 0:
            self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
        else:
            self.conv_in = ChunkedConv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)

        # Middle blocks with attention
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
        self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)

        # Build upsampling blocks
        for i_level, ch in enumerate(block_out_channels):
            # NOTE: self.up is created just above this loop in the original
            # flow; see assignment below.
            pass
        self.up = nn.ModuleList()
        for i_level, ch in enumerate(block_out_channels):
            block = nn.ModuleList()
            block_out = ch

            # Add residual blocks for this level (extra block for decoder)
            for _ in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
                block_in = block_out

            up = nn.Module()
            up.block = block

            # Add spatial upsampling if needed: only the first
            # log2(ffactor_spatial) levels upsample; the rest keep resolution.
            add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
            if add_spatial_upsample:
                assert i_level < len(block_out_channels) - 1
                block_out = block_out_channels[i_level + 1]
                up.upsample = Upsample(block_in, block_out, chunk_size=chunk_size)
                block_in = block_out

            self.up.append(up)

        # Output layers
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        if chunk_size is None or chunk_size <= 0:
            self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
        else:
            self.conv_out = ChunkedConv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)

    def forward(self, z: Tensor) -> Tensor:
        """Decode latents (B, z_channels, H, W) into images with out_channels
        channels and spatially upsampled resolution."""
        # Initial processing with skip connection: repeat latent channels to
        # match the first block width and add to the conv_in output.
        repeats = self.block_out_channels[0] // self.z_channels
        h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

        # Middle processing with attention
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # Progressive upsampling through blocks
        for i_level in range(len(self.block_out_channels)):
            # Apply residual blocks at this level
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
            # Apply spatial upsampling if available
            if hasattr(self.up[i_level], "upsample"):
                h = self.up[i_level].upsample(h)

        # Final output layers
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
|
|
|
|
|
|
class HunyuanVAE2D(nn.Module):
    """
    VAE model for Hunyuan Image-2.1 with spatial tiling support.

    This VAE uses a fixed architecture optimized for the Hunyuan Image-2.1 model,
    with 32x spatial compression and optional memory-efficient tiling for large images.

    Parameters
    ----------
    chunk_size : int, optional
        Forwarded to the encoder/decoder; a positive value enables chunked
        convolutions to reduce peak memory.
    """

    def __init__(self, chunk_size: Optional[int] = None):
        super().__init__()

        # Fixed configuration for Hunyuan Image-2.1
        block_out_channels = (128, 256, 512, 512, 1024, 1024)
        in_channels = 3  # RGB input
        out_channels = 3  # RGB output
        latent_channels = 64
        layers_per_block = 2
        ffactor_spatial = 32  # 32x spatial compression
        sample_size = 384  # Minimum sample size for tiling
        scaling_factor = LATENT_SCALING_FACTOR  # 0.75289 # Latent scaling factor

        self.ffactor_spatial = ffactor_spatial
        self.scaling_factor = scaling_factor

        self.encoder = Encoder(
            in_channels=in_channels,
            z_channels=latent_channels,
            block_out_channels=block_out_channels,
            num_res_blocks=layers_per_block,
            ffactor_spatial=ffactor_spatial,
            chunk_size=chunk_size,
        )

        # Decoder mirrors the encoder: channel widths reversed.
        self.decoder = Decoder(
            z_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=list(reversed(block_out_channels)),
            num_res_blocks=layers_per_block,
            ffactor_spatial=ffactor_spatial,
            chunk_size=chunk_size,
        )

        # Spatial tiling configuration for memory efficiency
        self.use_spatial_tiling = False
        self.tile_sample_min_size = sample_size
        self.tile_latent_min_size = sample_size // ffactor_spatial
        self.tile_overlap_factor = 0.25  # 25% overlap between tiles

    @property
    def dtype(self):
        """Get the data type of the model parameters."""
        return next(self.encoder.parameters()).dtype

    @property
    def device(self):
        """Get the device of the model parameters."""
        return next(self.encoder.parameters()).device

    def enable_spatial_tiling(self, use_tiling: bool = True):
        """Enable or disable spatial tiling."""
        self.use_spatial_tiling = use_tiling

    def disable_spatial_tiling(self):
        """Disable spatial tiling."""
        self.use_spatial_tiling = False

    def enable_tiling(self, use_tiling: bool = True):
        """Enable or disable spatial tiling (alias for enable_spatial_tiling)."""
        self.enable_spatial_tiling(use_tiling)

    def disable_tiling(self):
        """Disable spatial tiling (alias for disable_spatial_tiling)."""
        self.disable_spatial_tiling()

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Blend two tensors horizontally with smooth transition.

        Writes the blended columns into `b` in place and returns it.

        Parameters
        ----------
        a : torch.Tensor
            Left tensor.
        b : torch.Tensor
            Right tensor.
        blend_extent : int
            Number of columns to blend.
        """
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        # Linear crossfade: weight shifts from `a` to `b` across the overlap.
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Blend two tensors vertically with smooth transition.

        Writes the blended rows into `b` in place and returns it.

        Parameters
        ----------
        a : torch.Tensor
            Top tensor.
        b : torch.Tensor
            Bottom tensor.
        blend_extent : int
            Number of rows to blend.
        """
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        # Linear crossfade: weight shifts from `a` to `b` across the overlap.
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode large images using spatial tiling to reduce memory usage.
        Tiles are processed independently and blended at boundaries.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (B, C, T, H, W) or (B, C, H, W).
        """
        # Handle 5D input (B, C, T, H, W) by removing time dimension
        original_ndim = x.ndim
        if original_ndim == 5:
            x = x.squeeze(2)

        B, C, H, W = x.shape
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Encode overlapping tiles independently.
        rows = []
        for i in range(0, H, overlap_size):
            row = []
            for j in range(0, W, overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                row.append(tile)
            rows.append(row)

        # Crossfade tile borders, crop the overlaps, and stitch the result.
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))

        moments = torch.cat(result_rows, dim=-2)
        return moments

    def spatial_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode large latents using spatial tiling to reduce memory usage.
        Tiles are processed independently and blended at boundaries.

        Parameters
        ----------
        z : torch.Tensor
            Latent tensor of shape (B, C, H, W).
        """
        B, C, H, W = z.shape
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Decode overlapping latent tiles independently.
        rows = []
        for i in range(0, H, overlap_size):
            row = []
            for j in range(0, W, overlap_size):
                # Fixed: z is 4D (B, C, H, W) here; the previous indexing used
                # five indices (an extra ':'), raising IndexError on every
                # tiled decode.
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)

        # Crossfade tile borders, crop the overlaps, and stitch the result.
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                # Fixed: decoded tiles are 4D; crop height/width only (the
                # previous 5D crop also raised IndexError).
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=-2)
        return dec

    def encode(self, x: Tensor) -> "DiagonalGaussianDistribution":
        """
        Encode input images to latent representation.
        Uses spatial tiling for large images if enabled.

        Parameters
        ----------
        x : Tensor
            Input image tensor of shape (B, C, H, W) or (B, C, T, H, W).

        Returns
        -------
        DiagonalGaussianDistribution
            Latent distribution with mean and logvar.
        """
        # Handle 5D input (B, C, T, H, W) by removing time dimension
        original_ndim = x.ndim
        if original_ndim == 5:
            x = x.squeeze(2)

        # Use tiling for large images to reduce memory usage
        if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            h = self.spatial_tiled_encode(x)
        else:
            h = self.encoder(x)

        # Restore time dimension if input was 5D
        if original_ndim == 5:
            h = h.unsqueeze(2)

        posterior = DiagonalGaussianDistribution(h)
        return posterior

    def decode(self, z: Tensor):
        """
        Decode latent representation back to images.
        Uses spatial tiling for large latents if enabled.

        Parameters
        ----------
        z : Tensor
            Latent tensor of shape (B, C, H, W) or (B, C, T, H, W).

        Returns
        -------
        Tensor
            Decoded image tensor.
        """
        # Handle 5D input (B, C, T, H, W) by removing time dimension
        original_ndim = z.ndim
        if original_ndim == 5:
            z = z.squeeze(2)

        # Use tiling for large latents to reduce memory usage
        if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            decoded = self.spatial_tiled_decode(z)
        else:
            decoded = self.decoder(z)

        # Restore time dimension if input was 5D
        if original_ndim == 5:
            decoded = decoded.unsqueeze(2)

        return decoded
|
|
|
|
|
|
def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False, chunk_size: Optional[int] = None) -> HunyuanVAE2D:
    """
    Build a HunyuanVAE2D and load its weights from a safetensors checkpoint.

    Parameters
    ----------
    vae_path : str
        Path to the safetensors checkpoint.
    device : torch.device
        Device to load the weights onto.
    disable_mmap : bool
        Disable memory-mapped loading of the checkpoint.
    chunk_size : int, optional
        Chunk size forwarded to the chunked convolutions (None disables chunking).
    """
    logger.info(f"Initializing VAE with chunk_size={chunk_size}")
    vae = HunyuanVAE2D(chunk_size=chunk_size)

    logger.info(f"Loading VAE from {vae_path}")
    sd = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)
    load_result = vae.load_state_dict(sd, strict=True, assign=True)
    logger.info(f"Loaded VAE: {load_result}")

    vae.to(device)
    return vae
|