feat: add vae_chunk_size argument for memory-efficient VAE decoding and processing
@@ -29,14 +29,20 @@ def swish(x: Tensor) -> Tensor:
 class AttnBlock(nn.Module):
     """Self-attention block using scaled dot-product attention."""

-    def __init__(self, in_channels: int):
+    def __init__(self, in_channels: int, chunk_size: Optional[int] = None):
         super().__init__()
         self.in_channels = in_channels
         self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-        self.q = Conv2d(in_channels, in_channels, kernel_size=1)
-        self.k = Conv2d(in_channels, in_channels, kernel_size=1)
-        self.v = Conv2d(in_channels, in_channels, kernel_size=1)
-        self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
+        if chunk_size is None or chunk_size <= 0:
+            self.q = Conv2d(in_channels, in_channels, kernel_size=1)
+            self.k = Conv2d(in_channels, in_channels, kernel_size=1)
+            self.v = Conv2d(in_channels, in_channels, kernel_size=1)
+            self.proj_out = Conv2d(in_channels, in_channels, kernel_size=1)
+        else:
+            self.q = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
+            self.k = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
+            self.v = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)
+            self.proj_out = ChunkedConv2d(in_channels, in_channels, kernel_size=1, chunk_size=chunk_size)

     def attention(self, x: Tensor) -> Tensor:
         x = self.norm(x)
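[Reviewer note, not part of the diff] For the 1x1 projections above, chunking along the height axis is lossless: a pointwise convolution has no cross-row receptive field, so applying it band by band and concatenating matches the full pass up to floating-point reduction order. A minimal sketch with made-up shapes:

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 8, kernel_size=1)
x = torch.randn(1, 8, 97, 64)

# Apply the same conv to height bands of 32 rows and stitch the results back together
bands = [conv(x[:, :, i : i + 32, :]) for i in range(0, x.shape[2], 32)]
assert torch.allclose(conv(x), torch.cat(bands, dim=2))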
@@ -56,6 +62,87 @@ class AttnBlock(nn.Module):
         return x + self.proj_out(self.attention(x))


+class ChunkedConv2d(nn.Conv2d):
+    """
+    Convolutional layer that processes input in chunks to reduce memory usage.
+
+    Parameters
+    ----------
+    chunk_size : int, optional
+        Number of rows (height) to process at a time. Default is 64.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Pop before super().__init__, since nn.Conv2d rejects unknown kwargs
+        self.chunk_size = kwargs.pop("chunk_size", 64)
+        super().__init__(*args, **kwargs)
+        assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
+        assert self.dilation == (1, 1) and self.stride == (1, 1), "Only dilation=1 and stride=1 are supported."
+        assert self.groups == 1, "Only groups=1 is supported."
+        assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
+        assert (
+            self.padding[0] == self.padding[1] and self.padding[0] == self.kernel_size[0] // 2
+        ), "Only kernel_size//2 padding is supported."
+        self.original_padding = self.padding
+        self.padding = (0, 0)  # We handle padding manually in forward
+
+    def forward(self, x: Tensor) -> Tensor:
+        # If chunking is not needed, process normally. We chunk only along the height dimension.
+        if self.chunk_size is None or x.shape[2] <= self.chunk_size:
+            self.padding = self.original_padding
+            x = super().forward(x)
+            self.padding = (0, 0)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            return x
+
+        # Process input in chunks to reduce memory usage
+        org_shape = x.shape
+
+        # If kernel size is not 1, we need to use overlapping chunks
+        overlap = self.kernel_size[0] // 2  # 1 for kernel size 3
+        step = self.chunk_size - overlap
+        y = torch.zeros((org_shape[0], self.out_channels, org_shape[2], org_shape[3]), dtype=x.dtype, device=x.device)
+        yi = 0
+        i = 0
+        while i < org_shape[2]:
+            si = i if i == 0 else i - overlap
+            ei = i + self.chunk_size
+
+            # Check last chunk. If the remaining part is small, include it in the last chunk
+            if ei > org_shape[2] or ei + step // 4 > org_shape[2]:
+                ei = org_shape[2]
+
+            chunk = x[:, :, : ei - si, :]
+            x = x[:, :, ei - si - overlap * 2 :, :]
+
+            # Pad chunk if needed: this mimics the padding the original Conv2d would apply
+            if i == 0:  # First chunk
+                # Pad except bottom
+                chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0)
+            elif ei == org_shape[2]:  # Last chunk
+                # Pad except top
+                chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0)
+            else:
+                # Pad left and right only
+                chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0)
+
+            chunk = super().forward(chunk)
+            y[:, :, yi : yi + chunk.shape[2], :] = chunk
+            yi += chunk.shape[2]
+            del chunk
+
+            if ei == org_shape[2]:
+                break
+            i += step
+
+        assert yi == org_shape[2], f"yi={yi}, org_shape[2]={org_shape[2]}"
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()  # This helps reduce peak memory usage, but slows things down a bit
+        return y
+
+
 class ResnetBlock(nn.Module):
     """
     Residual block with two convolutions, group normalization, and swish activation.
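[Reviewer note, not part of the diff] For kernel_size=3 the bands must overlap by kernel_size // 2 rows, which is exactly what the forward above does: with H=100 and chunk_size=16 (overlap=1, step=15) it processes input rows [0, 16), [14, 31), [29, 46), and so on, folding any remainder smaller than step // 4 rows into the final band. A quick equivalence check against a plain Conv2d (the import path is hypothetical):

import torch
from library.hunyuan_image_vae import ChunkedConv2d  # hypothetical module path

torch.manual_seed(0)
ref = torch.nn.Conv2d(4, 4, kernel_size=3, padding=1)
chk = ChunkedConv2d(4, 4, kernel_size=3, padding=1, chunk_size=16)
chk.load_state_dict(ref.state_dict())  # identical weight/bias keys

x = torch.randn(1, 4, 100, 40)
assert torch.allclose(ref(x), chk(x), atol=1e-5)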
@@ -69,19 +156,29 @@ class ResnetBlock(nn.Module):
     Number of output channels.
     """

-    def __init__(self, in_channels: int, out_channels: int):
+    def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels

         self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
-        self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
-        self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        if chunk_size is None or chunk_size <= 0:
+            self.conv1 = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+            self.conv2 = Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

-        # Skip connection projection for channel dimension mismatch
-        if self.in_channels != self.out_channels:
-            self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+            # Skip connection projection for channel dimension mismatch
+            if self.in_channels != self.out_channels:
+                self.nin_shortcut = Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+        else:
+            self.conv1 = ChunkedConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
+            self.conv2 = ChunkedConv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
+
+            # Skip connection projection for channel dimension mismatch
+            if self.in_channels != self.out_channels:
+                self.nin_shortcut = ChunkedConv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, padding=0, chunk_size=chunk_size
+                )

     def forward(self, x: Tensor) -> Tensor:
         h = x
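[Reviewer note, not part of the diff] Both branches register the submodules under the same attribute names (conv1, conv2, nin_shortcut), and ChunkedConv2d subclasses nn.Conv2d without adding parameters, so existing checkpoints load unchanged whichever branch is taken. A sketch (the import path is hypothetical):

from library.hunyuan_image_vae import ResnetBlock  # hypothetical module path

plain = ResnetBlock(in_channels=128, out_channels=256)
chunked = ResnetBlock(in_channels=128, out_channels=256, chunk_size=64)
assert plain.state_dict().keys() == chunked.state_dict().keys()
chunked.load_state_dict(plain.state_dict())  # weights transfer 1:1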
@@ -113,12 +210,17 @@ class Downsample(nn.Module):
     Number of output channels (must be divisible by 4).
     """

-    def __init__(self, in_channels: int, out_channels: int):
+    def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
         super().__init__()
         factor = 4  # 2x2 spatial reduction factor
         assert out_channels % factor == 0

-        self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
+        if chunk_size is None or chunk_size <= 0:
+            self.conv = Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
+        else:
+            self.conv = ChunkedConv2d(
+                in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
+            )
         self.group_size = factor * in_channels // out_channels

     def forward(self, x: Tensor) -> Tensor:
@@ -147,10 +249,15 @@ class Upsample(nn.Module):
     Number of output channels.
     """

-    def __init__(self, in_channels: int, out_channels: int):
+    def __init__(self, in_channels: int, out_channels: int, chunk_size: Optional[int] = None):
         super().__init__()
         factor = 4  # 2x2 spatial expansion factor
-        self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
+
+        if chunk_size is None or chunk_size <= 0:
+            self.conv = Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
+        else:
+            self.conv = ChunkedConv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)
+
         self.repeats = factor * out_channels // in_channels

     def forward(self, x: Tensor) -> Tensor:
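[Reviewer note, not part of the diff] The factor = 4 bookkeeping in both blocks is consistent with a 2x2 space-to-depth rearrangement; the forward bodies are outside these hunks, so this is an assumption, not a description of the actual code. The channel arithmetic that the factor encodes:

import torch
import torch.nn.functional as F

x = torch.randn(1, 32, 64, 64)
down = F.pixel_unshuffle(x, 2)  # (1, 128, 32, 32): channels x4, H and W halved
up = F.pixel_shuffle(down, 2)   # back to (1, 32, 64, 64)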
@@ -191,6 +298,7 @@ class Encoder(nn.Module):
         block_out_channels: Tuple[int, ...],
         num_res_blocks: int,
         ffactor_spatial: int,
+        chunk_size: Optional[int] = None,
     ):
         super().__init__()
         assert block_out_channels[-1] % (2 * z_channels) == 0
@@ -199,7 +307,12 @@
         self.block_out_channels = block_out_channels
         self.num_res_blocks = num_res_blocks

-        self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+        if chunk_size is None or chunk_size <= 0:
+            self.conv_in = Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+        else:
+            self.conv_in = ChunkedConv2d(
+                in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, chunk_size=chunk_size
+            )

         self.down = nn.ModuleList()
         block_in = block_out_channels[0]
@@ -211,7 +324,7 @@

             # Add residual blocks for this level
             for _ in range(self.num_res_blocks):
-                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
                 block_in = block_out

             down = nn.Module()
@@ -222,20 +335,23 @@
             if add_spatial_downsample:
                 assert i_level < len(block_out_channels) - 1
                 block_out = block_out_channels[i_level + 1]
-                down.downsample = Downsample(block_in, block_out)
+                down.downsample = Downsample(block_in, block_out, chunk_size=chunk_size)
                 block_in = block_out

             self.down.append(down)

         # Middle blocks with attention
         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
-        self.mid.attn_1 = AttnBlock(block_in)
-        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
+        self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)

         # Output layers
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
-        self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+        if chunk_size is None or chunk_size <= 0:
+            self.conv_out = Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
+        else:
+            self.conv_out = ChunkedConv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)

     def forward(self, x: Tensor) -> Tensor:
         # Initial convolution
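[Reviewer note, not part of the diff] With these changes, every conv-bearing submodule of the encoder receives the same chunk_size. A construction sketch; the channel configuration and import path are illustrative assumptions, not the model's actual values:

from library.hunyuan_image_vae import Encoder  # hypothetical module path

enc = Encoder(
    in_channels=3,
    z_channels=64,
    block_out_channels=(128, 256, 512, 512),  # illustrative values
    num_res_blocks=2,
    ffactor_spatial=32,
    chunk_size=64,
)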
@@ -291,6 +407,7 @@ class Decoder(nn.Module):
         block_out_channels: Tuple[int, ...],
         num_res_blocks: int,
         ffactor_spatial: int,
+        chunk_size: Optional[int] = None,
     ):
         super().__init__()
         assert block_out_channels[0] % z_channels == 0
@@ -300,13 +417,16 @@
         self.num_res_blocks = num_res_blocks

         block_in = block_out_channels[0]
-        self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+        if chunk_size is None or chunk_size <= 0:
+            self.conv_in = Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
+        else:
+            self.conv_in = ChunkedConv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)

         # Middle blocks with attention
         self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
-        self.mid.attn_1 = AttnBlock(block_in)
-        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
+        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)
+        self.mid.attn_1 = AttnBlock(block_in, chunk_size=chunk_size)
+        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, chunk_size=chunk_size)

         # Build upsampling blocks
         self.up = nn.ModuleList()
@@ -316,7 +436,7 @@

             # Add residual blocks for this level (extra block for decoder)
             for _ in range(self.num_res_blocks + 1):
-                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
+                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, chunk_size=chunk_size))
                 block_in = block_out

             up = nn.Module()
@@ -327,14 +447,17 @@
             if add_spatial_upsample:
                 assert i_level < len(block_out_channels) - 1
                 block_out = block_out_channels[i_level + 1]
-                up.upsample = Upsample(block_in, block_out)
+                up.upsample = Upsample(block_in, block_out, chunk_size=chunk_size)
                 block_in = block_out

             self.up.append(up)

         # Output layers
         self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
-        self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
+        if chunk_size is None or chunk_size <= 0:
+            self.conv_out = Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
+        else:
+            self.conv_out = ChunkedConv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1, chunk_size=chunk_size)

     def forward(self, z: Tensor) -> Tensor:
         # Initial processing with skip connection
@@ -370,7 +493,7 @@ class HunyuanVAE2D(nn.Module):
     with 32x spatial compression and optional memory-efficient tiling for large images.
     """

-    def __init__(self):
+    def __init__(self, chunk_size: Optional[int] = None):
         super().__init__()

         # Fixed configuration for Hunyuan Image-2.1
@@ -392,6 +515,7 @@
             block_out_channels=block_out_channels,
             num_res_blocks=layers_per_block,
             ffactor_spatial=ffactor_spatial,
+            chunk_size=chunk_size,
         )

         self.decoder = Decoder(
@@ -400,6 +524,7 @@
             block_out_channels=list(reversed(block_out_channels)),
             num_res_blocks=layers_per_block,
             ffactor_spatial=ffactor_spatial,
+            chunk_size=chunk_size,
         )

         # Spatial tiling configuration for memory efficiency
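[Reviewer note, not part of the diff] HunyuanVAE2D itself keeps a fixed configuration, so the only new knob is chunk_size; passing None (the default) reproduces the previous behavior exactly:

vae = HunyuanVAE2D(chunk_size=64)  # ChunkedConv2d throughout encoder and decoder
vae_ref = HunyuanVAE2D()           # plain Conv2d everywhere, as before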
@@ -617,9 +742,9 @@
         return decoded


-def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False) -> HunyuanVAE2D:
-    logger.info("Initializing VAE")
-    vae = HunyuanVAE2D()
+def load_vae(vae_path: str, device: torch.device, disable_mmap: bool = False, chunk_size: Optional[int] = None) -> HunyuanVAE2D:
+    logger.info(f"Initializing VAE with chunk_size={chunk_size}")
+    vae = HunyuanVAE2D(chunk_size=chunk_size)

     logger.info(f"Loading VAE from {vae_path}")
     state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)
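[Reviewer note, not part of the diff] Typical call sites under the new signature (the file path is a placeholder):

import torch

vae = load_vae("/path/to/vae.safetensors", torch.device("cuda"), chunk_size=64)
vae_plain = load_vae("/path/to/vae.safetensors", torch.device("cuda"))  # chunking disabled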
@@ -626,6 +626,7 @@ class LatentsCachingStrategy:
         for key in npz.files:
             kwargs[key] = npz[key]

+        # TODO float() is needed if the VAE is in bfloat16. Remove it if the VAE is float16.
         kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
         kwargs["original_size" + key_reso_suffix] = np.array(original_size)
         kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
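[Reviewer note, not part of the diff] The float() cast the TODO refers to is load-bearing when the VAE runs in bfloat16: numpy has no bf16 dtype, so .numpy() raises on such tensors, while float16 converts directly:

import torch

t = torch.randn(4, dtype=torch.bfloat16)
# t.cpu().numpy()  # raises TypeError: Got unsupported ScalarType BFloat16
arr = t.float().cpu().numpy()  # cast to float32 first, then convert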