Files
Kohya-ss-sd-scripts/library/qwen_image_autoencoder_kl.py
Kohya S. 34e7138b6a Add/modify some implementation for anima (#2261)
* fix: update extend-exclude list in _typos.toml to include configs

* fix: exclude anima tests from pytest

* feat: add entry for 'temperal' in extend-words section of _typos.toml for Qwen-Image VAE

* fix: update default value for --discrete_flow_shift in anima training guide

* feat: add Qwen-Image VAE

* feat: simplify encode_tokens

* feat: use unified attention module, add wrapper for state dict compatibility

* feat: loading with dynamic fp8 optimization and LoRA support

* feat: add anima minimal inference script (WIP)

* format: format

* feat: simplify target module selection by regular expression patterns

* feat: kept caption dropout rate in cache and handle in training script

* feat: update train_llm_adapter and verbose default values to string type

* fix: use strategy instead of using tokenizers directly

* feat: add dtype property and all-zero mask handling in cross-attention in LLMAdapterTransformerBlock

* feat: support 5d tensor in get_noisy_model_input_and_timesteps

* feat: update loss calculation to support 5d tensor

* fix: update argument names in anima_train_utils to align with other architectures

* feat: simplify Anima training script and update empty caption handling

* feat: support LoRA format without `net.` prefix

* fix: update to work fp8_scaled option

* feat: add regex-based learning rates and dimensions handling in create_network

* fix: improve regex matching for module selection and learning rates in LoRANetwork

* fix: update logging message for regex match in LoRANetwork

* fix: keep latents 4D except DiT call

* feat: enhance block swap functionality for inference and training in Anima model

* feat: refactor Anima training script

* feat: optimize VAE processing by adjusting tensor dimensions and data types

* fix: wait for all block transfers before switching offloader mode

* feat: update Anima training guide with new argument specifications and regex-based module selection. Thank you Claude!

* feat: support LORA for Qwen3

* feat: update Anima SAI model spec metadata handling

* fix: remove unused code

* feat: split CFG processing in do_sample function to reduce memory usage

* feat: add VAE chunking and caching options to reduce memory usage

* feat: optimize RMSNorm forward method and remove unused torch_attention_op

* Update library/strategy_anima.py

Use torch.all instead of all.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/safetensors_utils.py

Fix duplicated new_key for concat_hook.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_minimal_inference.py

Remove unused code.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update anima_train.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update library/anima_train_utils.py

Remove unused import.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: review with Copilot

* feat: add script to convert LoRA format to ComfyUI compatible format (WIP, not tested yet)

* feat: add process_escape function to handle escape sequences in prompts

* feat: enhance LoRA weight handling in model loading and add text encoder loading function

* feat: improve ComfyUI conversion script with prefix constants and module name adjustments

* feat: update caption dropout documentation to clarify cache regeneration requirement

* feat: add clarification on learning rate adjustments

* feat: add note on PyTorch version requirement to prevent NaN loss

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-13 08:15:06 +09:00

1736 lines
69 KiB
Python

# Copied and modified from Diffusers (via Musubi-Tuner). Original copyright notice follows.
# Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We gratefully acknowledge the Wan Team for their outstanding contributions.
# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
# For more information about the Wan VAE, please refer to:
# - GitHub: https://github.com/Wan-Video/Wan2.1
# - arXiv: https://arxiv.org/abs/2503.20314
import json
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from library.safetensors_utils import load_safetensors
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# Number of trailing frames each causal conv keeps as cache for chunked temporal
# processing (consumed via the feat_cache/feat_idx protocol in the modules below).
CACHE_T = 2
SCALE_FACTOR = 8 # VAE downsampling factor
# region diffusers-vae
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian posterior parameterized by concatenated mean / log-variance.

    ``parameters`` carries mean and log-variance stacked along channel dim 1
    (shape ``(B, 2*C, ...)``); the two halves are split on construction.
    ``logvar`` is clamped to [-30, 20] for numerical stability. When
    ``deterministic`` is True, std/var are zeroed so ``sample()`` returns the mean.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate distribution: sample() collapses to the mean.
            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
        """Draw a reparameterized sample ``mean + std * eps`` with ``eps ~ N(0, I)``."""
        # make sure sample is on the same device as the parameters and has same dtype
        if generator is not None and generator.device.type != self.parameters.device.type:
            rand_device = generator.device
        else:
            rand_device = self.parameters.device
        sample = torch.randn(self.mean.shape, generator=generator, device=rand_device, dtype=self.parameters.dtype).to(
            self.parameters.device
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: Optional["DiagonalGaussianDistribution"] = None) -> torch.Tensor:
        """KL divergence to ``other``, or to the standard normal when ``other`` is None.

        Summed over dims 1-3 (assumes 4D parameters); returns a per-batch tensor.
        """
        if self.deterministic:
            return torch.tensor([0.0])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3],
                )
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3],
                )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        """Negative log-likelihood of ``sample`` under this Gaussian, summed over ``dims``.

        Note: default is an immutable tuple (the original used a mutable list default).
        """
        if self.deterministic:
            return torch.tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """Return the distribution mode (the mean, for a Gaussian)."""
        return self.mean
# endregion diffusers-vae
class ChunkedConv2d(nn.Conv2d):
    """
    Convolutional layer that processes input in chunks to reduce memory usage.

    Parameters
    ----------
    spatial_chunk_size : int, optional
        Size of chunks to process at a time. Default is None, which means no chunking.

    TODO: Commonize with similar implementation in hunyuan_image_vae.py
    """

    def __init__(self, *args, **kwargs):
        # spatial_chunk_size is ours, not nn.Conv2d's: pop it before delegating.
        if "spatial_chunk_size" in kwargs:
            self.spatial_chunk_size = kwargs.pop("spatial_chunk_size", None)
        else:
            self.spatial_chunk_size = None
        super().__init__(*args, **kwargs)
        # The manual per-chunk padding in forward() only handles this restricted configuration.
        assert self.padding_mode == "zeros", "Only 'zeros' padding mode is supported."
        assert self.dilation == (1, 1), "Only dilation=1 is supported."
        assert self.groups == 1, "Only groups=1 is supported."
        assert self.kernel_size[0] == self.kernel_size[1], "Only square kernels are supported."
        assert self.stride[0] == self.stride[1], "Only equal strides are supported."
        self.original_padding = self.padding
        self.padding = (0, 0)  # We handle padding manually in forward

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If chunking is not needed, process normally. We chunk only along height dimension.
        if (
            self.spatial_chunk_size is None
            or x.shape[2] <= self.spatial_chunk_size + self.kernel_size[0] + self.spatial_chunk_size // 4
        ):
            # Temporarily restore the real padding so nn.Conv2d pads for us.
            # NOTE(review): this mutates self.padding during forward, so this path is not thread-safe.
            self.padding = self.original_padding
            x = super().forward(x)
            self.padding = (0, 0)
            return x
        # Process input in chunks to reduce memory usage
        org_shape = x.shape
        # If kernel size is not 1, we need to use overlapping chunks
        overlap = self.kernel_size[0] // 2  # 1 for kernel size 3
        if self.original_padding[0] == 0:
            overlap = 0
        # If stride > 1, QwenImageVAE pads manually with zeros before convolution, so we do not need to consider it here
        y_height = org_shape[2] // self.stride[0]
        y_width = org_shape[3] // self.stride[1]
        y = torch.zeros((org_shape[0], self.out_channels, y_height, y_width), dtype=x.dtype, device=x.device)
        yi = 0  # next output row to fill
        i = 0  # top row (input coordinates) of the current chunk
        while i < org_shape[2]:
            si = i if i == 0 else i - overlap
            ei = i + self.spatial_chunk_size + overlap + self.stride[0] - 1
            # Check last chunk. If remaining part is small, include it in last chunk
            if ei > org_shape[2] or ei + self.spatial_chunk_size // 4 > org_shape[2]:
                ei = org_shape[2]
            chunk = x[:, :, si:ei, :]
            # Pad chunk if needed: This is as the original Conv2d with padding
            if i == 0 and overlap > 0:  # First chunk
                # Pad except bottom
                chunk = torch.nn.functional.pad(chunk, (overlap, overlap, overlap, 0), mode="constant", value=0)
            elif ei == org_shape[2] and overlap > 0:  # Last chunk
                # Pad except top
                chunk = torch.nn.functional.pad(chunk, (overlap, overlap, 0, overlap), mode="constant", value=0)
            elif overlap > 0:  # Middle chunks
                # Pad left and right only
                chunk = torch.nn.functional.pad(chunk, (overlap, overlap), mode="constant", value=0)
            # print(f"Processing chunk: org_shape={org_shape}, si={si}, ei={ei}, chunk.shape={chunk.shape}, overlap={overlap}")
            chunk = super().forward(chunk)
            # print(f" -> chunk after conv shape: {chunk.shape}")
            y[:, :, yi : yi + chunk.shape[2], :] = chunk
            yi += chunk.shape[2]
            del chunk
            if ei == org_shape[2]:
                break
            i += self.spatial_chunk_size
        # Sanity check: chunk outputs must tile the full output height exactly.
        assert yi == y_height, f"yi={yi}, y_height={y_height}"
        return y
class QwenImageCausalConv3d(nn.Conv3d):
    r"""
    A custom 3D causal convolution layer with feature caching support.

    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
    caching for efficient inference.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
        spatial_chunk_size (int, optional): If set, the convolution is applied in chunks of this
            many output rows along the height axis to lower peak memory.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        spatial_chunk_size: Optional[int] = None,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # Set up causal padding. F.pad order is (W_left, W_right, H_top, H_bottom, T_front, T_back):
        # the full temporal padding (2 * pad_t) is applied at the front only, so no output
        # position ever sees future frames.
        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)
        self.spatial_chunk_size = spatial_chunk_size
        # Height chunking is only valid when the kernel slides independently over height
        # (no grouping, no spatial dilation/stride).
        self._supports_spatial_chunking = (
            self.groups == 1 and self.dilation[1] == 1 and self.dilation[2] == 1 and self.stride[1] == 1 and self.stride[2] == 1
        )

    def _forward_chunked_height(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the conv over height chunks of the already-padded input; falls back to one pass."""
        chunk_size = self.spatial_chunk_size
        if chunk_size is None or chunk_size <= 0:
            return super().forward(x)
        if not self._supports_spatial_chunking:
            return super().forward(x)
        kernel_h = self.kernel_size[1]
        if kernel_h <= 1 or x.shape[3] <= chunk_size:
            return super().forward(x)
        receptive_h = kernel_h
        out_h = x.shape[3] - receptive_h + 1
        if out_h <= 0:
            return super().forward(x)
        y0 = 0
        out = None
        while y0 < out_h:
            # Each chunk needs kernel_h - 1 extra input rows of context below it.
            y1 = min(y0 + chunk_size, out_h)
            in0 = y0
            in1 = y1 + receptive_h - 1
            out_chunk = super().forward(x[:, :, :, in0:in1, :])
            if out is None:
                # Allocate the full output lazily, once channel/time output dims are known.
                out_shape = list(out_chunk.shape)
                out_shape[3] = out_h
                out = out_chunk.new_empty(out_shape)
            out[:, :, :, y0:y1, :] = out_chunk
            y0 = y1
        return out

    def forward(self, x, cache_x=None):
        """Apply the causal conv; ``cache_x`` holds trailing frames cached from the previous chunk.

        When a cache is given, cached frames are prepended on the time axis and the
        front temporal padding is reduced by the same amount, so chunked processing
        matches a single full-sequence pass.
        """
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return self._forward_chunked_height(x)
class QwenImageRMS_norm(nn.Module):
    r"""
    RMS-style normalization: L2-normalize along the channel dimension, then rescale
    by ``sqrt(dim)`` and a learnable per-channel gain (plus an optional learned bias).

    Args:
        dim (int): The number of dimensions to normalize over.
        channel_first (bool, optional): Whether the input tensor has channels as the
            first (non-batch) dimension. Default is True.
        images (bool, optional): Whether the input represents image data (2 trailing
            spatial dims) rather than video (3 trailing dims). Default is True.
        bias (bool, optional): Whether to include a learnable bias term. Default is False.
    """

    def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
        super().__init__()
        # Parameter shape must broadcast against the input layout.
        trailing_dims = (1, 1) if images else (1, 1, 1)
        param_shape = (dim, *trailing_dims) if channel_first else (dim,)
        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(param_shape))
        self.bias = nn.Parameter(torch.zeros(param_shape)) if bias else 0.0

    def forward(self, x):
        norm_dim = 1 if self.channel_first else -1
        normalized = F.normalize(x, dim=norm_dim)
        return normalized * self.scale * self.gamma + self.bias
class QwenImageUpsample(nn.Upsample):
    r"""
    nn.Upsample variant that computes the interpolation in float32 and casts the
    result back to the input's dtype, so reduced-precision inputs round-trip safely.

    Args:
        x (torch.Tensor): Input tensor to be upsampled.

    Returns:
        torch.Tensor: Upsampled tensor with the same data type as the input.
    """

    def forward(self, x):
        # Upsample in full precision, then restore the caller's dtype.
        upsampled = super().forward(x.float())
        return upsampled.type_as(x)
class QwenImageResample(nn.Module):
    r"""
    A custom resampling module for 2D and 3D data.

    Args:
        dim (int): The number of input/output channels.
        mode (str): The resampling mode. Must be one of:
            - 'none': No resampling (identity operation).
            - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
            - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
            - 'downsample2d': 2D downsampling with zero-padding and convolution.
            - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
    """

    def __init__(self, dim: int, mode: str) -> None:
        super().__init__()
        self.dim = dim
        self.mode = mode
        # layers
        if mode == "upsample2d":
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                ChunkedConv2d(dim, dim // 2, 3, padding=1),
            )
        elif mode == "upsample3d":
            # Spatial path is identical to upsample2d; temporal doubling is handled by
            # time_conv in forward().
            self.resample = nn.Sequential(
                QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
                ChunkedConv2d(dim, dim // 2, 3, padding=1),
            )
            # Produces 2*dim channels, which forward() interleaves along time to double t.
            self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
        elif mode == "downsample2d":
            # Asymmetric zero pad (right/bottom) + stride-2 conv halves H and W.
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), ChunkedConv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == "downsample3d":
            self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), ChunkedConv2d(dim, dim, 3, stride=(2, 2)))
            # Stride-2 temporal conv halves the frame count.
            self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # x: [b, c, t, h, w]. feat_idx is an in/out cursor into feat_cache; the mutable
        # default is only a placeholder — caching callers always pass their own list.
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: "Rep" sentinel tells the next chunk there is no real cache yet.
                    feat_cache[idx] = "Rep"
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -CACHE_T:, :, :].clone()
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
                        # cache last frame of last two chunk
                        cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                    if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
                        # No usable previous frame: pad the cache with zeros instead.
                        cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
                    if feat_cache[idx] == "Rep":
                        x = self.time_conv(x)
                    else:
                        x = self.time_conv(x, feat_cache[idx])
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
                    # time_conv doubled the channels; split them into two frame sets and
                    # interleave along time to double the temporal resolution.
                    x = x.reshape(b, 2, c, t, h, w)
                    x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                    x = x.reshape(b, c, t * 2, h, w)
        t = x.shape[2]
        # Fold time into batch so the 2D resample ops run per frame.
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.resample(x)
        x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
        if self.mode == "downsample3d":
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
                    # First chunk: remember it; temporal downsampling starts from the next chunk.
                    feat_cache[idx] = x.clone()
                    feat_idx[0] += 1
                else:
                    cache_x = x[:, :, -1:, :, :].clone()
                    # Prepend the last cached frame so the stride-2 temporal conv stays
                    # causal and consistent across chunk boundaries.
                    x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
                    feat_cache[idx] = cache_x
                    feat_idx[0] += 1
        return x
class QwenImageResidualBlock(nn.Module):
    r"""
    A custom residual block module.

    Two (RMS-norm -> SiLU -> causal conv) stages with dropout, plus a shortcut
    (1x1x1 causal conv when channel counts differ, identity otherwise).

    Args:
        in_dim (int): Number of input channels.
        out_dim (int): Number of output channels.
        dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
        non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        dropout: float = 0.0,
        non_linearity: str = "silu",
    ) -> None:
        assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently."
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.nonlinearity = nn.SiLU()  # get_activation(non_linearity)
        # layers
        self.norm1 = QwenImageRMS_norm(in_dim, images=False)
        self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
        self.norm2 = QwenImageRMS_norm(out_dim, images=False)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
        # Channel projection only when shapes would otherwise mismatch.
        self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # feat_idx is an in/out cursor into feat_cache; caching callers pass their own list.
        # Apply shortcut connection
        h = self.conv_shortcut(x)
        # First normalization and activation
        x = self.norm1(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            # Keep the last CACHE_T frames for the next chunk.
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # Current chunk has a single frame: borrow the previous cache's last
                # frame so the stored cache always holds two frames.
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)
        # Second normalization and activation
        x = self.norm2(x)
        x = self.nonlinearity(x)
        # Dropout
        x = self.dropout(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv2(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv2(x)
        # Add residual connection
        return x + h
class QwenImageAttentionBlock(nn.Module):
    r"""
    Single-head self-attention applied per frame over the spatial positions,
    with a residual connection.

    Args:
        dim (int): The number of channels in the input tensor.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # layers
        self.norm = QwenImageRMS_norm(dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        residual = x
        b, c, t, h, w = x.size()
        # Fold time into batch: [b, c, t, h, w] -> [(b*t), c, h, w]
        x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        x = self.norm(x)
        # Project to q/k/v with a single 1x1 conv, then flatten the spatial positions
        # into the sequence axis expected by scaled_dot_product_attention.
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous()
        q, k, v = qkv.chunk(3, dim=-1)
        attn_out = F.scaled_dot_product_attention(q, k, v)
        x = attn_out.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
        # output projection
        x = self.proj(x)
        # Restore layout: [(b*t), c, h, w] -> [b, c, t, h, w]
        x = x.view(b, t, c, h, w).permute(0, 2, 1, 3, 4)
        return x + residual
class QwenImageMidBlock(nn.Module):
    """
    Middle block for QwenImageVAE encoder and decoder: one leading residual block
    followed by ``num_layers`` (attention, residual) pairs.

    Args:
        dim (int): Number of input/output channels.
        dropout (float): Dropout rate.
        non_linearity (str): Type of non-linearity to use.
        num_layers (int): Number of attention/residual pairs after the first resnet.
    """

    def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
        super().__init__()
        self.dim = dim
        # Build the leading resnet, then interleaved attention/resnet pairs.
        resnet_layers = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
        attention_layers = []
        for _ in range(num_layers):
            attention_layers.append(QwenImageAttentionBlock(dim))
            resnet_layers.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
        self.attentions = nn.ModuleList(attention_layers)
        self.resnets = nn.ModuleList(resnet_layers)
        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """Apply the leading resnet, then each (attention, resnet) pair in order."""
        hidden = self.resnets[0](x, feat_cache, feat_idx)
        for attention, resnet in zip(self.attentions, self.resnets[1:]):
            if attention is not None:
                hidden = attention(hidden)
            hidden = resnet(hidden, feat_cache, feat_idx)
        return hidden
class QwenImageEncoder3d(nn.Module):
    r"""
    A 3D encoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_downsample (list of bool): Whether to downsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        input_channels (int): Number of input channels.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
        input_channels: int = 3,
        non_linearity: str = "silu",
    ):
        super().__init__()
        assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently."
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.nonlinearity = nn.SiLU()  # get_activation(non_linearity)
        # dimensions
        dims = [dim * u for u in [1] + dim_mult]
        # scale tracks the current spatial resolution relative to the input,
        # used only to match against attn_scales.
        scale = 1.0
        # init block
        self.conv_in = QwenImageCausalConv3d(input_channels, dims[0], 3, padding=1)
        # downsample blocks
        self.down_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            for _ in range(num_res_blocks):
                self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    self.down_blocks.append(QwenImageAttentionBlock(out_dim))
                in_dim = out_dim
            # downsample block
            if i != len(dim_mult) - 1:
                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
                self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
                scale /= 2.0
        # middle blocks
        self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # x: [b, c, t, h, w]. feat_cache/feat_idx implement chunked temporal processing:
        # each causal conv reads its cache slot at feat_idx[0] and advances the cursor.
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)
        ## downsamples
        for layer in self.down_blocks:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)
        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)
        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x
class QwenImageUpBlock(nn.Module):
    """
    Decoder upsampling stage: a stack of residual blocks, optionally followed by
    one resampling (upsample) layer.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks
        dropout (float): Dropout rate
        upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        upsample_mode: Optional[str] = None,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # First resnet maps in_dim -> out_dim; the rest stay at out_dim.
        blocks = []
        channels = in_dim
        for _ in range(num_res_blocks + 1):
            blocks.append(QwenImageResidualBlock(channels, out_dim, dropout, non_linearity))
            channels = out_dim
        self.resnets = nn.ModuleList(blocks)
        # Optional trailing resample layer.
        self.upsamplers = None
        if upsample_mode is not None:
            self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management

        Returns:
            torch.Tensor: Output tensor
        """
        for resnet in self.resnets:
            x = resnet(x, feat_cache, feat_idx) if feat_cache is not None else resnet(x)
        if self.upsamplers is not None:
            upsampler = self.upsamplers[0]
            x = upsampler(x, feat_cache, feat_idx) if feat_cache is not None else upsampler(x)
        return x
class QwenImageDecoder3d(nn.Module):
    r"""
    A 3D decoder module.

    Args:
        dim (int): The base number of channels in the first layer.
        z_dim (int): The dimensionality of the latent space.
        dim_mult (list of int): Multipliers for the number of channels in each block.
        num_res_blocks (int): Number of residual blocks in each block.
        attn_scales (list of float): Scales at which to apply attention mechanisms.
        temperal_upsample (list of bool): Whether to upsample temporally in each block.
        dropout (float): Dropout rate for the dropout layers.
        output_channels (int): Number of output channels.
        non_linearity (str): Type of non-linearity to use.
    """

    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_upsample=[False, True, True],
        dropout=0.0,
        output_channels: int = 3,
        non_linearity: str = "silu",
    ):
        super().__init__()
        assert non_linearity in ["silu"], "Only 'silu' non-linearity is supported currently."
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_upsample = temperal_upsample
        self.nonlinearity = nn.SiLU()  # get_activation(non_linearity)
        # dimensions (mirror of the encoder, starting from the deepest width)
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        scale = 1.0 / 2 ** (len(dim_mult) - 2)
        # init block
        self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
        # middle blocks
        self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
        # upsample blocks
        self.up_blocks = nn.ModuleList([])
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            # residual (+attention) blocks
            if i > 0:
                # The previous stage's upsample conv halved the channels (see
                # QwenImageResample upsample modes), so adjust the incoming width.
                in_dim = in_dim // 2
            # Determine if we need upsampling
            upsample_mode = None
            if i != len(dim_mult) - 1:
                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
            # Create and add the upsampling block
            up_block = QwenImageUpBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                num_res_blocks=num_res_blocks,
                dropout=dropout,
                upsample_mode=upsample_mode,
                non_linearity=non_linearity,
            )
            self.up_blocks.append(up_block)
            # Update scale for next iteration
            if upsample_mode is not None:
                scale *= 2.0
        # output blocks
        self.norm_out = QwenImageRMS_norm(out_dim, images=False)
        self.conv_out = QwenImageCausalConv3d(out_dim, output_channels, 3, padding=1)
        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        # x: latent [b, z, t, h, w]. feat_cache/feat_idx implement chunked temporal
        # decoding: each causal conv reads its slot at feat_idx[0] and advances it.
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_in(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)
        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)
        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)
        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x
class AutoencoderKLQwenImage(nn.Module): # ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = False
# @register_to_config
def __init__(
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
latents_mean: List[float] = [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
],
latents_std: List[float] = [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
],
input_channels: int = 3,
spatial_chunk_size: Optional[int] = None,
disable_cache: bool = False,
) -> None:
super().__init__()
self.z_dim = z_dim
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
self.latents_mean = latents_mean
self.latents_std = latents_std
self.encoder = QwenImageEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout, input_channels
)
self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
self.decoder = QwenImageDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout, input_channels
)
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
self.use_slicing = False
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
# intermediate tiles together, the memory requirement can be lowered.
self.use_tiling = False
# The minimal tile height and width for spatial tiling to be used
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
# The minimal distance between two spatial tiles
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules()) if self.decoder is not None else 0,
"encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules()) if self.encoder is not None else 0,
}
self.spatial_chunk_size = None
if spatial_chunk_size is not None and spatial_chunk_size > 0:
self.enable_spatial_chunking(spatial_chunk_size)
self.cache_disabled = False
if disable_cache:
self.disable_cache()
@property
def dtype(self):
return self.encoder.parameters().__next__().dtype
@property
def device(self):
return self.encoder.parameters().__next__().device
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_sample_stride_height: Optional[float] = None,
tile_sample_stride_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_sample_stride_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension.
tile_sample_stride_width (`int`, *optional*):
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
artifacts produced across the width dimension.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def enable_spatial_chunking(self, spatial_chunk_size: int) -> None:
r"""
Enable memory-efficient convolution by chunking all causal Conv3d layers only along height.
"""
if spatial_chunk_size is None or spatial_chunk_size <= 0:
raise ValueError(f"`spatial_chunk_size` must be a positive integer, got {spatial_chunk_size}.")
self.spatial_chunk_size = int(spatial_chunk_size)
for module in self.modules():
if isinstance(module, QwenImageCausalConv3d):
module.spatial_chunk_size = self.spatial_chunk_size
elif isinstance(module, ChunkedConv2d):
module.spatial_chunk_size = self.spatial_chunk_size
def disable_spatial_chunking(self) -> None:
r"""
Disable memory-efficient convolution chunking on all causal Conv3d layers.
"""
self.spatial_chunk_size = None
for module in self.modules():
if isinstance(module, QwenImageCausalConv3d):
module.spatial_chunk_size = None
elif isinstance(module, ChunkedConv2d):
module.spatial_chunk_size = None
def disable_cache(self) -> None:
r"""
Disable caching mechanism in encoder and decoder.
"""
self.cache_disabled = True
self.clear_cache = lambda: None
self._feat_map = None # Disable decoder cache
self._enc_feat_map = None # Disable encoder cache
def clear_cache(self):
def _count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, QwenImageCausalConv3d):
count += 1
return count
self._conv_num = _count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
    def _encode(self, x: torch.Tensor):
        # Encode x of shape (B, C, T, H, W) to pre-quantization latent moments.
        # The temporal loop feeds the encoder 1 frame first, then chunks of 4 frames,
        # relying on the causal-conv feature cache — hence the assert below.
        _, _, num_frame, height, width = x.shape
        assert num_frame == 1 or not self.cache_disabled, "Caching must be enabled for encoding multiple frames."
        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
            return self.tiled_encode(x)
        self.clear_cache()
        # First iteration consumes 1 frame, each later iteration consumes 4.
        iter_ = 1 + (num_frame - 1) // 4
        for i in range(iter_):
            # Reset the running conv index before each temporal chunk.
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(
                    x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
                    feat_cache=self._enc_feat_map,
                    feat_idx=self._enc_conv_idx,
                )
                # Concatenate along the temporal axis.
                out = torch.cat([out, out_], 2)
        enc = self.quant_conv(out)
        self.clear_cache()
        return enc
# @apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[Dict[str, torch.Tensor], Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a dictionary is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return {"latent_dist": posterior}
def _decode(self, z: torch.Tensor, return_dict: bool = True):
_, _, num_frame, height, width = z.shape
assert num_frame == 1 or not self.cache_disabled, "Caching must be enabled for encoding multiple frames."
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
self.clear_cache()
x = self.post_quant_conv(z)
for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()
if not return_dict:
return (out,)
return {"sample": out}
# @apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice)["sample"] for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)["sample"]
if not return_dict:
return (decoded,)
return {"sample": decoded}
def decode_to_pixels(self, latents: torch.Tensor) -> torch.Tensor:
is_4d = latents.dim() == 4
if is_4d:
latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]
latents = latents.to(self.dtype)
latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
image = self.decode(latents, return_dict=False)[0] # -1 to 1
if is_4d:
image = image.squeeze(2) # [B, C, 1, H, W] -> [B, C, H, W]
return image.clamp(-1.0, 1.0)
    def encode_pixels_to_latents(self, pixels: torch.Tensor) -> torch.Tensor:
        """
        Convert pixel values to latents and apply normalization using mean/std.
        Args:
            pixels (torch.Tensor): Input pixels in [-1, 1] range with shape [B, C, H, W] or [B, C, T, H, W]
                (the [0, 1] -> [-1, 1] conversion below is intentionally commented out, so
                callers must pass [-1, 1] data — see the `__main__` debug script)
        Returns:
            torch.Tensor: Normalized latents, 4D in -> 4D out, 5D in -> 5D out
        """
        # # Convert from [0, 1] to [-1, 1] range
        # pixels = (pixels * 2.0 - 1.0).clamp(-1.0, 1.0)
        # Handle 4D (single-image) input by adding a temporal dimension
        is_4d = pixels.dim() == 4
        if is_4d:
            pixels = pixels.unsqueeze(2)  # [B, C, H, W] -> [B, C, 1, H, W]
        pixels = pixels.to(self.dtype)
        # Encode to latent space
        posterior = self.encode(pixels, return_dict=False)[0]
        latents = posterior.mode()  # Use mode instead of sampling for deterministic results
        # latents = posterior.sample()
        # Apply normalization using mean/std; note `latents_std` here holds 1/std,
        # so the multiplication below divides by the channel std.
        latents_mean = torch.tensor(self.latents_mean).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
        latents_std = 1.0 / torch.tensor(self.latents_std).view(1, self.z_dim, 1, 1, 1).to(latents.device, latents.dtype)
        latents = (latents - latents_mean) * latents_std
        if is_4d:
            latents = latents.squeeze(2)  # [B, C, 1, H, W] -> [B, C, H, W]
        return latents
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        Splits the (H, W) plane into overlapping tiles, encodes each tile with the
        temporally-chunked encoder, then blends neighboring tiles to hide seams.

        Args:
            x (`torch.Tensor`): Input batch of videos.
        Returns:
            `torch.Tensor`:
                The latent representation of the encoded videos.
        """
        _, _, num_frames, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
        # Overlap (in latent units) between adjacent tiles, used for blending below.
        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width
        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                # Each tile gets a fresh causal-conv cache.
                self.clear_cache()
                time = []
                # First temporal chunk holds 1 frame, later chunks hold 4.
                frame_range = 1 + (num_frames - 1) // 4
                for k in range(frame_range):
                    self._enc_conv_idx = [0]
                    if k == 0:
                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                    else:
                        tile = x[
                            :,
                            :,
                            1 + 4 * (k - 1) : 1 + 4 * k,
                            i : i + self.tile_sample_min_height,
                            j : j + self.tile_sample_min_width,
                        ]
                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                    tile = self.quant_conv(tile)
                    time.append(tile)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))
        # Crop to the exact latent size (edge tiles may run past the boundary).
        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return enc
    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Splits the latent (H, W) plane into overlapping tiles, decodes each tile
        frame by frame, then blends neighboring tiles to hide seams.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a dictionary instead of a plain tuple.
        Returns:
            `dict` or `tuple`:
                If return_dict is True, a dictionary is returned, otherwise a plain `tuple` is
                returned.
        """
        _, _, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio
        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
        # Overlap (in sample units) between adjacent decoded tiles, used for blending below.
        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                # Each tile gets a fresh causal-conv cache.
                self.clear_cache()
                time = []
                for k in range(num_frames):
                    self._conv_idx = [0]
                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                    tile = self.post_quant_conv(tile)
                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                    time.append(decoded)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))
        # Crop to the exact sample size (edge tiles may run past the boundary).
        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
        if not return_dict:
            return (dec,)
        return {"sample": dec}
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`Dict[str, torch.Tensor]`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
# region utils
# This region is not included in the original implementation. Added for musubi-tuner/sd-scripts.
# Convert ComfyUI keys to standard keys if necessary
def convert_comfyui_state_dict(sd):
    """Convert a ComfyUI-layout Qwen-Image VAE state dict to the official key layout.

    Detection: ComfyUI checkpoints contain a top-level ``conv1.bias`` (the official
    layout names it ``quant_conv.bias``); dicts without that key are returned unchanged.

    Args:
        sd: Mapping of parameter name -> tensor.

    Returns:
        A new dict with remapped keys (tensor values are shared, not copied), or `sd`
        itself when no conversion is needed.
    """
    if "conv1.bias" not in sd:
        return sd
    # Key mapping from ComfyUI VAE to official VAE, auto-generated by a script
    key_map = {
        "conv1": "quant_conv",
        "conv2": "post_quant_conv",
        "decoder.conv1": "decoder.conv_in",
        "decoder.head.0": "decoder.norm_out",
        "decoder.head.2": "decoder.conv_out",
        "decoder.middle.0.residual.0": "decoder.mid_block.resnets.0.norm1",
        "decoder.middle.0.residual.2": "decoder.mid_block.resnets.0.conv1",
        "decoder.middle.0.residual.3": "decoder.mid_block.resnets.0.norm2",
        "decoder.middle.0.residual.6": "decoder.mid_block.resnets.0.conv2",
        "decoder.middle.1.norm": "decoder.mid_block.attentions.0.norm",
        "decoder.middle.1.proj": "decoder.mid_block.attentions.0.proj",
        "decoder.middle.1.to_qkv": "decoder.mid_block.attentions.0.to_qkv",
        "decoder.middle.2.residual.0": "decoder.mid_block.resnets.1.norm1",
        "decoder.middle.2.residual.2": "decoder.mid_block.resnets.1.conv1",
        "decoder.middle.2.residual.3": "decoder.mid_block.resnets.1.norm2",
        "decoder.middle.2.residual.6": "decoder.mid_block.resnets.1.conv2",
        "decoder.upsamples.0.residual.0": "decoder.up_blocks.0.resnets.0.norm1",
        "decoder.upsamples.0.residual.2": "decoder.up_blocks.0.resnets.0.conv1",
        "decoder.upsamples.0.residual.3": "decoder.up_blocks.0.resnets.0.norm2",
        "decoder.upsamples.0.residual.6": "decoder.up_blocks.0.resnets.0.conv2",
        "decoder.upsamples.1.residual.0": "decoder.up_blocks.0.resnets.1.norm1",
        "decoder.upsamples.1.residual.2": "decoder.up_blocks.0.resnets.1.conv1",
        "decoder.upsamples.1.residual.3": "decoder.up_blocks.0.resnets.1.norm2",
        "decoder.upsamples.1.residual.6": "decoder.up_blocks.0.resnets.1.conv2",
        "decoder.upsamples.10.residual.0": "decoder.up_blocks.2.resnets.2.norm1",
        "decoder.upsamples.10.residual.2": "decoder.up_blocks.2.resnets.2.conv1",
        "decoder.upsamples.10.residual.3": "decoder.up_blocks.2.resnets.2.norm2",
        "decoder.upsamples.10.residual.6": "decoder.up_blocks.2.resnets.2.conv2",
        "decoder.upsamples.11.resample.1": "decoder.up_blocks.2.upsamplers.0.resample.1",
        "decoder.upsamples.12.residual.0": "decoder.up_blocks.3.resnets.0.norm1",
        "decoder.upsamples.12.residual.2": "decoder.up_blocks.3.resnets.0.conv1",
        "decoder.upsamples.12.residual.3": "decoder.up_blocks.3.resnets.0.norm2",
        "decoder.upsamples.12.residual.6": "decoder.up_blocks.3.resnets.0.conv2",
        "decoder.upsamples.13.residual.0": "decoder.up_blocks.3.resnets.1.norm1",
        "decoder.upsamples.13.residual.2": "decoder.up_blocks.3.resnets.1.conv1",
        "decoder.upsamples.13.residual.3": "decoder.up_blocks.3.resnets.1.norm2",
        "decoder.upsamples.13.residual.6": "decoder.up_blocks.3.resnets.1.conv2",
        "decoder.upsamples.14.residual.0": "decoder.up_blocks.3.resnets.2.norm1",
        "decoder.upsamples.14.residual.2": "decoder.up_blocks.3.resnets.2.conv1",
        "decoder.upsamples.14.residual.3": "decoder.up_blocks.3.resnets.2.norm2",
        "decoder.upsamples.14.residual.6": "decoder.up_blocks.3.resnets.2.conv2",
        "decoder.upsamples.2.residual.0": "decoder.up_blocks.0.resnets.2.norm1",
        "decoder.upsamples.2.residual.2": "decoder.up_blocks.0.resnets.2.conv1",
        "decoder.upsamples.2.residual.3": "decoder.up_blocks.0.resnets.2.norm2",
        "decoder.upsamples.2.residual.6": "decoder.up_blocks.0.resnets.2.conv2",
        "decoder.upsamples.3.resample.1": "decoder.up_blocks.0.upsamplers.0.resample.1",
        "decoder.upsamples.3.time_conv": "decoder.up_blocks.0.upsamplers.0.time_conv",
        "decoder.upsamples.4.residual.0": "decoder.up_blocks.1.resnets.0.norm1",
        "decoder.upsamples.4.residual.2": "decoder.up_blocks.1.resnets.0.conv1",
        "decoder.upsamples.4.residual.3": "decoder.up_blocks.1.resnets.0.norm2",
        "decoder.upsamples.4.residual.6": "decoder.up_blocks.1.resnets.0.conv2",
        "decoder.upsamples.4.shortcut": "decoder.up_blocks.1.resnets.0.conv_shortcut",
        "decoder.upsamples.5.residual.0": "decoder.up_blocks.1.resnets.1.norm1",
        "decoder.upsamples.5.residual.2": "decoder.up_blocks.1.resnets.1.conv1",
        "decoder.upsamples.5.residual.3": "decoder.up_blocks.1.resnets.1.norm2",
        "decoder.upsamples.5.residual.6": "decoder.up_blocks.1.resnets.1.conv2",
        "decoder.upsamples.6.residual.0": "decoder.up_blocks.1.resnets.2.norm1",
        "decoder.upsamples.6.residual.2": "decoder.up_blocks.1.resnets.2.conv1",
        "decoder.upsamples.6.residual.3": "decoder.up_blocks.1.resnets.2.norm2",
        "decoder.upsamples.6.residual.6": "decoder.up_blocks.1.resnets.2.conv2",
        "decoder.upsamples.7.resample.1": "decoder.up_blocks.1.upsamplers.0.resample.1",
        "decoder.upsamples.7.time_conv": "decoder.up_blocks.1.upsamplers.0.time_conv",
        "decoder.upsamples.8.residual.0": "decoder.up_blocks.2.resnets.0.norm1",
        "decoder.upsamples.8.residual.2": "decoder.up_blocks.2.resnets.0.conv1",
        "decoder.upsamples.8.residual.3": "decoder.up_blocks.2.resnets.0.norm2",
        "decoder.upsamples.8.residual.6": "decoder.up_blocks.2.resnets.0.conv2",
        "decoder.upsamples.9.residual.0": "decoder.up_blocks.2.resnets.1.norm1",
        "decoder.upsamples.9.residual.2": "decoder.up_blocks.2.resnets.1.conv1",
        "decoder.upsamples.9.residual.3": "decoder.up_blocks.2.resnets.1.norm2",
        "decoder.upsamples.9.residual.6": "decoder.up_blocks.2.resnets.1.conv2",
        "encoder.conv1": "encoder.conv_in",
        "encoder.downsamples.0.residual.0": "encoder.down_blocks.0.norm1",
        "encoder.downsamples.0.residual.2": "encoder.down_blocks.0.conv1",
        "encoder.downsamples.0.residual.3": "encoder.down_blocks.0.norm2",
        "encoder.downsamples.0.residual.6": "encoder.down_blocks.0.conv2",
        "encoder.downsamples.1.residual.0": "encoder.down_blocks.1.norm1",
        "encoder.downsamples.1.residual.2": "encoder.down_blocks.1.conv1",
        "encoder.downsamples.1.residual.3": "encoder.down_blocks.1.norm2",
        "encoder.downsamples.1.residual.6": "encoder.down_blocks.1.conv2",
        "encoder.downsamples.10.residual.0": "encoder.down_blocks.10.norm1",
        "encoder.downsamples.10.residual.2": "encoder.down_blocks.10.conv1",
        "encoder.downsamples.10.residual.3": "encoder.down_blocks.10.norm2",
        "encoder.downsamples.10.residual.6": "encoder.down_blocks.10.conv2",
        "encoder.downsamples.2.resample.1": "encoder.down_blocks.2.resample.1",
        "encoder.downsamples.3.residual.0": "encoder.down_blocks.3.norm1",
        "encoder.downsamples.3.residual.2": "encoder.down_blocks.3.conv1",
        "encoder.downsamples.3.residual.3": "encoder.down_blocks.3.norm2",
        "encoder.downsamples.3.residual.6": "encoder.down_blocks.3.conv2",
        "encoder.downsamples.3.shortcut": "encoder.down_blocks.3.conv_shortcut",
        "encoder.downsamples.4.residual.0": "encoder.down_blocks.4.norm1",
        "encoder.downsamples.4.residual.2": "encoder.down_blocks.4.conv1",
        "encoder.downsamples.4.residual.3": "encoder.down_blocks.4.norm2",
        "encoder.downsamples.4.residual.6": "encoder.down_blocks.4.conv2",
        "encoder.downsamples.5.resample.1": "encoder.down_blocks.5.resample.1",
        "encoder.downsamples.5.time_conv": "encoder.down_blocks.5.time_conv",
        "encoder.downsamples.6.residual.0": "encoder.down_blocks.6.norm1",
        "encoder.downsamples.6.residual.2": "encoder.down_blocks.6.conv1",
        "encoder.downsamples.6.residual.3": "encoder.down_blocks.6.norm2",
        "encoder.downsamples.6.residual.6": "encoder.down_blocks.6.conv2",
        "encoder.downsamples.6.shortcut": "encoder.down_blocks.6.conv_shortcut",
        "encoder.downsamples.7.residual.0": "encoder.down_blocks.7.norm1",
        "encoder.downsamples.7.residual.2": "encoder.down_blocks.7.conv1",
        "encoder.downsamples.7.residual.3": "encoder.down_blocks.7.norm2",
        "encoder.downsamples.7.residual.6": "encoder.down_blocks.7.conv2",
        "encoder.downsamples.8.resample.1": "encoder.down_blocks.8.resample.1",
        "encoder.downsamples.8.time_conv": "encoder.down_blocks.8.time_conv",
        "encoder.downsamples.9.residual.0": "encoder.down_blocks.9.norm1",
        "encoder.downsamples.9.residual.2": "encoder.down_blocks.9.conv1",
        "encoder.downsamples.9.residual.3": "encoder.down_blocks.9.norm2",
        "encoder.downsamples.9.residual.6": "encoder.down_blocks.9.conv2",
        "encoder.head.0": "encoder.norm_out",
        "encoder.head.2": "encoder.conv_out",
        "encoder.middle.0.residual.0": "encoder.mid_block.resnets.0.norm1",
        "encoder.middle.0.residual.2": "encoder.mid_block.resnets.0.conv1",
        "encoder.middle.0.residual.3": "encoder.mid_block.resnets.0.norm2",
        "encoder.middle.0.residual.6": "encoder.mid_block.resnets.0.conv2",
        "encoder.middle.1.norm": "encoder.mid_block.attentions.0.norm",
        "encoder.middle.1.proj": "encoder.mid_block.attentions.0.proj",
        "encoder.middle.1.to_qkv": "encoder.mid_block.attentions.0.to_qkv",
        "encoder.middle.2.residual.0": "encoder.mid_block.resnets.1.norm1",
        "encoder.middle.2.residual.2": "encoder.mid_block.resnets.1.conv1",
        "encoder.middle.2.residual.3": "encoder.mid_block.resnets.1.norm2",
        "encoder.middle.2.residual.6": "encoder.mid_block.resnets.1.conv2",
    }
    new_state_dict = {}
    for key, value in sd.items():
        # Strip the trailing ".weight"/".bias" component and look up the module prefix.
        prefix = key.rsplit(".", 1)[0]
        mapped = key_map.get(prefix)
        if mapped is not None:
            # Rebuild the key anchored at the start. The previous `str.replace` would
            # also have rewritten an identical substring occurring later in the key.
            new_key = mapped + key[len(prefix) :]
        else:
            new_key = key
        new_state_dict[new_key] = value
    logger.info("Converted ComfyUI AutoencoderKL state dict keys to official format")
    return new_state_dict
def load_vae(
    vae_path: str,
    input_channels: int = 3,
    device: Union[str, torch.device] = "cpu",
    disable_mmap: bool = False,
    spatial_chunk_size: Optional[int] = None,
    disable_cache: bool = False,
) -> AutoencoderKLQwenImage:
    """Load VAE from a given path.

    Args:
        vae_path: Path to a `.safetensors` checkpoint (official or ComfyUI key layout).
        input_channels: Number of image channels fed to the encoder.
        device: Device the weights are loaded onto and the model is moved to.
        disable_mmap: Passed through to `load_safetensors`.
        spatial_chunk_size: When set and > 0, enables height-chunked convolution;
            odd values are rounded up to the next even number.
        disable_cache: When True, disables the causal-conv feature cache
            (restricts encode/decode to single-frame input).

    Returns:
        An `AutoencoderKLQwenImage` with weights loaded (strict), on `device`.
    """
    # Hard-coded config matching the bundled Qwen-Image VAE checkpoint, so no
    # separate config file is required at load time.
    VAE_CONFIG_JSON = """
{
  "_class_name": "AutoencoderKLQwenImage",
  "_diffusers_version": "0.34.0.dev0",
  "attn_scales": [],
  "base_dim": 96,
  "dim_mult": [
    1,
    2,
    4,
    4
  ],
  "dropout": 0.0,
  "latents_mean": [
    -0.7571,
    -0.7089,
    -0.9113,
    0.1075,
    -0.1745,
    0.9653,
    -0.1517,
    1.5508,
    0.4134,
    -0.0715,
    0.5517,
    -0.3632,
    -0.1922,
    -0.9497,
    0.2503,
    -0.2921
  ],
  "latents_std": [
    2.8184,
    1.4541,
    2.3275,
    2.6558,
    1.2196,
    1.7708,
    2.6052,
    2.0743,
    3.2687,
    2.1526,
    2.8652,
    1.5579,
    1.6382,
    1.1253,
    2.8251,
    1.916
  ],
  "num_res_blocks": 2,
  "temperal_downsample": [
    false,
    true,
    true
  ],
  "z_dim": 16
}
"""
    logger.info("Initializing VAE")
    # Height chunking works on even chunk sizes; round odd values up.
    if spatial_chunk_size is not None and spatial_chunk_size % 2 != 0:
        spatial_chunk_size += 1
        logger.warning(f"Adjusted spatial_chunk_size to the next even number: {spatial_chunk_size}")
    config = json.loads(VAE_CONFIG_JSON)
    vae = AutoencoderKLQwenImage(
        base_dim=config["base_dim"],
        z_dim=config["z_dim"],
        dim_mult=config["dim_mult"],
        num_res_blocks=config["num_res_blocks"],
        attn_scales=config["attn_scales"],
        temperal_downsample=config["temperal_downsample"],
        dropout=config["dropout"],
        latents_mean=config["latents_mean"],
        latents_std=config["latents_std"],
        input_channels=input_channels,
        spatial_chunk_size=spatial_chunk_size,
        disable_cache=disable_cache,
    )
    logger.info(f"Loading VAE from {vae_path}")
    state_dict = load_safetensors(vae_path, device=device, disable_mmap=disable_mmap)
    # Convert ComfyUI VAE keys to official VAE keys
    state_dict = convert_comfyui_state_dict(state_dict)
    # `assign=True` adopts the loaded tensors directly instead of copying into
    # the freshly initialized parameters.
    info = vae.load_state_dict(state_dict, strict=True, assign=True)
    logger.info(f"Loaded VAE: {info}")
    vae.to(device)
    return vae
if __name__ == "__main__":
    # Debugging / testing code: round-trips images through the VAE and reports
    # reconstruction error, timing and (on CUDA) peak memory for several
    # chunking/caching configurations.
    import argparse
    import glob
    import os
    import time
    from PIL import Image
    from library.device_utils import get_preferred_device, synchronize_device

    parser = argparse.ArgumentParser()
    parser.add_argument("--vae", type=str, required=True, help="Path to the VAE model file.")
    parser.add_argument("--input_image_dir", type=str, required=True, help="Path to the input image directory.")
    parser.add_argument("--output_image_dir", type=str, required=True, help="Path to the output image directory.")
    args = parser.parse_args()

    # Load VAE
    vae = load_vae(args.vae, device=get_preferred_device())

    # Process images
    def encode_decode_image(image_path, output_path):
        # Encode one image to latents, decode it back, save the reconstruction and
        # print the mean absolute pixel difference.
        image = Image.open(image_path).convert("RGB")
        # Crop to multiple of 8 (the VAE's spatial compression ratio)
        width, height = image.size
        new_width = (width // 8) * 8
        new_height = (height // 8) * 8
        if new_width != width or new_height != height:
            image = image.crop((0, 0, new_width, new_height))
        # HWC uint8 -> BCHW float in [-1, 1], the range `encode_pixels_to_latents` expects
        image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0 * 2 - 1
        image_tensor = image_tensor.to(vae.dtype).to(vae.device)
        with torch.no_grad():
            latents = vae.encode_pixels_to_latents(image_tensor)
            reconstructed = vae.decode_to_pixels(latents)
        diff = (image_tensor - reconstructed).abs().mean().item()
        print(f"Processed {image_path} (size: {image.size}), reconstruction diff: {diff}")
        reconstructed_image = ((reconstructed.squeeze(0).permute(1, 2, 0).float().cpu().numpy() + 1) / 2 * 255).astype(np.uint8)
        Image.fromarray(reconstructed_image).save(output_path)

    def process_directory(input_dir, output_dir):
        # Round-trip every .jpg/.png in input_dir, measuring wall time and, on CUDA,
        # peak allocated memory.
        if get_preferred_device().type == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()
        synchronize_device(get_preferred_device())
        start_time = time.perf_counter()
        os.makedirs(output_dir, exist_ok=True)
        image_paths = glob.glob(os.path.join(input_dir, "*.jpg")) + glob.glob(os.path.join(input_dir, "*.png"))
        for image_path in image_paths:
            filename = os.path.basename(image_path)
            output_path = os.path.join(output_dir, filename)
            encode_decode_image(image_path, output_path)
        if get_preferred_device().type == "cuda":
            max_mem = torch.cuda.max_memory_allocated() / (1024**3)
            print(f"Max GPU memory allocated: {max_mem:.2f} GB")
        synchronize_device(get_preferred_device())
        end_time = time.perf_counter()
        print(f"Processing time: {end_time - start_time:.2f} seconds")

    print("Starting image processing with default settings...")
    process_directory(args.input_image_dir, args.output_image_dir)
    print("Starting image processing with spatial chunking enabled with chunk size 64...")
    vae.enable_spatial_chunking(64)
    process_directory(args.input_image_dir, args.output_image_dir + "_chunked_64")
    print("Starting image processing with spatial chunking enabled with chunk size 16...")
    vae.enable_spatial_chunking(16)
    process_directory(args.input_image_dir, args.output_image_dir + "_chunked_16")
    print("Starting image processing without caching and chunking enabled with chunk size 64...")
    vae.enable_spatial_chunking(64)
    vae.disable_cache()
    process_directory(args.input_image_dir, args.output_image_dir + "_no_cache_chunked_64")
    print("Starting image processing without caching and chunking enabled with chunk size 16...")
    # Fix: this run previously kept chunk size 64 from the run above despite the
    # printed label; explicitly set the advertised chunk size.
    vae.enable_spatial_chunking(16)
    vae.disable_cache()
    process_directory(args.input_image_dir, args.output_image_dir + "_no_cache_chunked_16")
    print("Starting image processing without caching and chunking disabled...")
    vae.disable_spatial_chunking()
    process_directory(args.input_image_dir, args.output_image_dir + "_no_cache")
    print("Processing completed.")