Files
Kohya-ss-sd-scripts/library/sd3_models.py
2024-10-24 22:02:05 +09:00

1234 lines
43 KiB
Python

# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref
# the original code is licensed under the MIT License
# and some modules/classes were contributed by KohakuBlueleaf. Thanks for the contribution!
from ast import Tuple
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
import math
from types import SimpleNamespace
from typing import Dict, List, Optional, Union
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import CLIPTokenizer, T5TokenizerFast
from library.device_utils import clean_memory_on_device
from .utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
# xformers is an optional dependency. ``memory_efficient_attention`` stays ``None``
# when it cannot be imported so that callers can detect the missing backend.
memory_efficient_attention = None
try:
    import xformers
except Exception:
    # ``Exception`` (not a bare ``except``) keeps the import best-effort without
    # swallowing KeyboardInterrupt / SystemExit.
    pass
try:
    from xformers.ops import memory_efficient_attention
except Exception:
    memory_efficient_attention = None
# region mmdit
@dataclass
class SD3Params:
    """Hyper-parameters describing one SD3 MMDiT configuration.

    Most fields are forwarded verbatim to ``MMDiT`` by ``create_sd3_mmdit``.
    """

    # --- geometry of the patchified input ---
    patch_size: int
    depth: int  # MMDiT derives hidden size (64 * depth) and head count from this
    num_patches: int  # length of the pre-allocated ``pos_embed`` buffer
    pos_embed_max_size: int  # side length of the square positional grid

    # --- conditioning ---
    adm_in_channels: int  # input width of the ``y`` embedder
    qk_norm: Optional[str]  # "rms", "ln" or None (see AttentionLinears)
    x_block_self_attn_layers: List[int]  # not consumed by the code visible in this file
    context_embedder_in_features: int
    context_embedder_out_features: int

    # --- bookkeeping ---
    model_type: str
def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    scaling_factor=None,
    offset=None,
):
    """Build a sin-cos positional embedding for a square ``grid_size`` x ``grid_size`` grid.

    The coordinate grid is optionally divided by ``scaling_factor`` and then
    shifted by ``-offset`` before being encoded.

    Returns:
        numpy array produced by ``get_2d_sincos_pos_embed_from_grid``
        (one row per grid cell, ``embed_dim`` columns).
    """
    axis = np.arange(grid_size, dtype=np.float32)
    # meshgrid is called as (w, h): the width coordinate plane comes first
    coords = np.stack(np.meshgrid(axis, axis), axis=0)

    if scaling_factor is not None:
        coords = coords / scaling_factor
    if offset is not None:
        coords = coords - offset

    coords = coords.reshape([2, 1, grid_size, grid_size])
    return get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-plane coordinate grid into a sin-cos embedding.

    ``grid[0]`` fills the first ``embed_dim // 2`` columns and ``grid[1]`` the
    remaining ones. ``embed_dim`` must be even.

    Returns:
        numpy array of shape (H*W, embed_dim).
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # each coordinate plane gets half of the channels: (H*W, D/2) apiece
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(per_axis, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with sines and cosines at geometrically spaced frequencies.

    embed_dim: output dimension for each position (must be even)
    pos: array of positions to be encoded; it is flattened to size (M,)
    out: (M, D) -- the first D/2 columns are sines, the last D/2 are cosines
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2

    exponents = np.arange(half, dtype=np.float64) / (embed_dim / 2.0)
    freqs = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = flat_pos[:, None] * freqs[None, :]  # (M, D/2), outer product

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
def get_1d_sincos_pos_embed_from_grid_torch(
    embed_dim,
    pos,
    device=None,
    dtype=torch.float32,
):
    """Torch counterpart of the 1-D sin-cos positional encoding.

    ``pos`` is flattened to (M,); the result is (M, 2 * (embed_dim // 2)) with
    sines in the first half of the columns and cosines in the second half.
    """
    half = embed_dim // 2
    exponents = torch.arange(half, device=device, dtype=dtype)
    exponents *= 2.0 / embed_dim
    freqs = 1.0 / 10000**exponents

    angles = torch.outer(pos.reshape(-1), freqs)
    return torch.cat((angles.sin(), angles.cos()), dim=1)
def get_2d_sincos_pos_embed_torch(
    embed_dim,
    w,
    h,
    val_center=7.5,
    val_magnitude=7.5,
    device=None,
    dtype=torch.float32,
):
    """Compute a 2-D sin-cos positional embedding for an ``h`` x ``w`` grid on the fly.

    Coordinates are laid out with ``linspace`` around ``val_center``; the shorter
    side spans ``+-val_magnitude`` and the longer side is stretched by the
    aspect ratio.

    Returns:
        tensor of shape (h*w, D); the width encoding occupies the first half of
        the channels and the height encoding the second half.
    """
    shorter_side = min(h, w)
    half_span_h = (h / shorter_side) * val_magnitude
    half_span_w = (w / shorter_side) * val_magnitude

    rows = torch.linspace(-half_span_h + val_center, half_span_h + val_center, h, device=device, dtype=dtype)
    cols = torch.linspace(-half_span_w + val_center, half_span_w + val_center, w, device=device, dtype=dtype)
    grid_h, grid_w = torch.meshgrid(rows, cols, indexing="ij")

    half_dim = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(half_dim, grid_h, device=device, dtype=dtype)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(half_dim, grid_w, device=device, dtype=dtype)
    return torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
def modulate(x, shift, scale):
    """Apply adaLN modulation: ``x * (1 + scale) + shift``.

    ``scale`` and ``shift`` get a new axis inserted at dim 1 before being
    applied to ``x``. A ``shift`` of ``None`` is treated as all zeros.
    """
    shift = torch.zeros_like(scale) if shift is None else shift
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def default(x, default_value):
    """Return ``x`` unless it is ``None``, in which case return ``default_value``."""
    return default_value if x is None else x
def timestep_embedding(t, dim, max_period=10000):
    """Create sinusoidal timestep embeddings.

    Args:
        t: 1-D tensor of timesteps (indexed as ``t[:, None]``).
        dim: width of the returned embedding; an odd ``dim`` gets one trailing zero column.
        max_period: controls the lowest frequency of the embedding.

    Returns:
        (len(t), dim) tensor laid out as [cos | sin | optional zero pad]. It is
        cast to ``t.dtype`` when ``t`` is floating point, otherwise left as float32.
    """
    half = dim // 2
    # frequencies are always computed in float32 and only moved to t's device
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=t.device)

    angles = t[:, None].float() * freqs[None]
    pieces = [torch.cos(angles), torch.sin(angles)]
    if dim % 2:
        pieces.append(torch.zeros_like(angles[:, :1]))
    embedding = torch.cat(pieces, dim=-1)

    if torch.is_floating_point(t):
        embedding = embedding.to(dtype=t.dtype)
    return embedding
class PatchEmbed(nn.Module):
    """Turn an image-like tensor into a sequence of patch embeddings.

    A strided ``Conv2d`` (kernel == stride == ``patch_size``) projects every
    non-overlapping patch to ``embed_dim`` channels.
    """

    def __init__(
        self,
        img_size=256,
        patch_size=4,
        in_channels=3,
        embed_dim=512,
        norm_layer=None,
        flatten=True,
        bias=True,
        strict_img_size=True,
        dynamic_img_pad=False,
    ):
        # dynamic_img_pad and norm is omitted in SD3.5
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten
        self.strict_img_size = strict_img_size  # stored only; not consulted in forward()
        self.dynamic_img_pad = dynamic_img_pad

        # grid metadata is only known when a fixed image size is given
        self.img_size = img_size
        if img_size is None:
            self.grid_size = None
            self.num_patches = None
        else:
            self.grid_size = img_size // patch_size
            self.num_patches = self.grid_size**2

        self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x):
        """Project ``x`` of shape (B, C, H, W) to patch embeddings.

        Returns (B, num_patches, embed_dim) when ``flatten`` is set, otherwise
        the raw convolution output (B, embed_dim, H', W').
        """
        _, _, height, width = x.shape
        if self.dynamic_img_pad:
            # Pad input so we won't have partial patch
            p = self.patch_size
            pad_bottom = (p - height % p) % p
            pad_right = (p - width % p) % p
            x = nn.functional.pad(x, (0, pad_right, 0, pad_bottom), mode="reflect")

        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        return self.norm(x)
# FinalLayer in mmdit.py
class UnPatch(nn.Module):
    """Final adaLN-modulated projection that also folds tokens back into an image.

    Each token is projected to ``patch_size**2 * out_channels`` values, which are
    then rearranged into a (B, C, h * p, w * p) tensor.
    """

    def __init__(self, hidden_size=512, patch_size=4, out_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.c = out_channels

        # eps is default in mmdit.py
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size),
        )

    def forward(self, x: torch.Tensor, cmod, H=None, W=None):
        """Modulate, project and un-patchify ``x`` of shape (B, N, hidden).

        ``H`` / ``W`` are pixel-space sizes used to recover the token grid; when
        both are omitted a square grid is assumed. The grid must contain exactly
        N cells, otherwise an AssertionError is raised.
        """
        batch, n_tokens, _ = x.shape
        p, c = self.patch_size, self.c

        if H is None and W is None:
            # no size hints: assume a square token grid
            rows = cols = int(n_tokens**0.5)
        else:
            rows = H // p if H else n_tokens // (W // p)
            cols = W // p if W else n_tokens // rows
        assert rows * cols == n_tokens

        shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1)
        x = self.linear(modulate(self.norm_final(x), shift, scale))

        # (B, h, w, p, p, C) -> (B, C, h, p, w, p) -> (B, C, h*p, w*p)
        x = x.view(batch, rows, cols, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        return x.view(batch, c, rows * p, cols * p)
class MLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> norm -> fc2.

    With ``use_conv`` the two layers are 1x1 ``Conv1d`` instead of ``Linear``.
    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    falsy.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=lambda: nn.GELU(),
        norm_layer=None,
        bias=True,
        use_conv=False,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.use_conv = use_conv

        make_layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear
        self.fc1 = make_layer(in_features, hidden_features, bias=bias)
        self.fc2 = make_layer(hidden_features, out_features, bias=bias)
        self.act = act_layer()
        self.norm = nn.Identity() if not norm_layer else norm_layer(hidden_features)

    def forward(self, x):
        """Run the feed-forward stack on ``x``."""
        hidden = self.norm(self.act(self.fc1(x)))
        return self.fc2(hidden)
class TimestepEmbedding(nn.Module):
    """Embed scalar timesteps: sinusoidal features followed by a small MLP."""

    def __init__(self, hidden_size, freq_embed_size=256):
        super().__init__()
        layers = [
            nn.Linear(freq_embed_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)
        self.freq_embed_size = freq_embed_size

    def forward(self, t, dtype=None, **kwargs):
        """Return the MLP-projected embedding of timesteps ``t``.

        The sinusoidal features are cast with ``.to(dtype)`` before entering the
        MLP. Extra keyword arguments are accepted and ignored.
        """
        freq_features = timestep_embedding(t, self.freq_embed_size).to(dtype)
        return self.mlp(freq_features)
class Embedder(nn.Module):
    """Project a flat input vector to ``hidden_size`` with Linear -> SiLU -> Linear."""

    def __init__(self, input_dim, hidden_size):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedded representation of ``x``."""
        embedded = self.mlp(x)
        return embedded
def rmsnorm(x, eps=1e-6):
    """Scale ``x`` by the reciprocal root-mean-square taken over its last dimension."""
    mean_square = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(mean_square + eps)
class RMSNorm(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine: bool = False,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.
        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): When True a learnable per-channel
                scale of size ``dim`` is applied after normalization. Default is False.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
            device: Device for the learnable scale (if any).
            dtype: Dtype for the learnable scale (if any).
        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            learnable_scale (bool): Copy of ``elementwise_affine``.
            weight (nn.Parameter or None): Learnable scaling parameter, or None
                when ``elementwise_affine`` is False.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            # Start from the identity scale. The previous ``torch.empty`` left the
            # parameter holding arbitrary memory (possibly NaN/inf) until a
            # checkpoint was loaded; loading a state dict still overwrites this.
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.
        Args:
            x (torch.Tensor): The input tensor.
        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        x = rmsnorm(x, eps=self.eps)
        if self.learnable_scale:
            # the scale follows the input's device/dtype rather than the reverse
            return x * self.weight.to(device=x.device, dtype=x.dtype)
        else:
            return x
class SwiGLUFeedForward(nn.Module):
    """SwiGLU feed-forward layer: ``w2(silu(w1(x)) * w3(x))`` with bias-free linears.

    The requested ``hidden_dim`` is shrunk to two thirds, optionally multiplied
    by ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float = None,
    ):
        super().__init__()
        inner_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        # round up to the next multiple of `multiple_of`
        n_chunks = (inner_dim + multiple_of - 1) // multiple_of
        inner_dim = multiple_of * n_chunks

        self.w1 = nn.Linear(dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inner_dim, bias=False)

    def forward(self, x):
        """Apply the gated feed-forward transform to ``x``."""
        gate = nn.functional.silu(self.w1(x))
        return self.w2(gate * self.w3(x))
# Linears for SelfAttention in mmdit.py
class AttentionLinears(nn.Module):
    """Linear projections around an attention op (the attention itself lives elsewhere).

    ``pre_attention`` yields q/k/v from a fused projection and applies the
    optional per-head q/k normalization; ``post_attention`` applies the output
    projection, which does not exist when ``pre_only`` is set.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if not pre_only:
            self.proj = nn.Linear(dim, dim)
        self.pre_only = pre_only

        def build_qk_norm():
            # one fresh normalization module per call, chosen by `qk_norm`
            if qk_norm is None:
                return nn.Identity()
            if qk_norm == "rms":
                return RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
            if qk_norm == "ln":
                return nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
            raise ValueError(qk_norm)

        self.ln_q = build_qk_norm()
        self.ln_k = build_qk_norm()

    def pre_attention(self, x: torch.Tensor) -> tuple:
        """
        Project ``x`` ([B, L, C]) to queries, keys and values.

        output:
            q, k: [B, L, D] (normalized per head, then flattened again)
            v: [B, L, num_heads, head_dim] (left in the per-head layout)
        """
        batch, seq_len, _ = x.shape
        per_head = self.qkv(x).reshape(batch, seq_len, -1, self.head_dim)
        q, k, v = per_head.chunk(3, dim=2)

        q = self.ln_q(q).reshape(batch, seq_len, -1)
        k = self.ln_k(k).reshape(batch, seq_len, -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the output projection; invalid for ``pre_only`` instances."""
        assert not self.pre_only
        return self.proj(x)
def _split_heads_first(x, head_dim):
    """[B, L, D] -> [B, H, L, head_dim]."""
    return x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2)


def _merge_heads_first(x):
    """[B, H, L, head_dim] -> [B, L, D]."""
    return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1)


def _split_seq_first(x, head_dim):
    """[B, L, D] -> [B, L, H, head_dim]."""
    return x.reshape(x.shape[0], x.shape[1], -1, head_dim)


def _merge_seq_first(x):
    """[B, L, H, head_dim] -> [B, L, D]."""
    return x.reshape(x.shape[0], x.shape[1], -1)


# Per attention backend: (layout before attention, layout after attention, shape helper).
# The third entry returns a 4-tuple with `x` placed on the head axis of that
# layout; it is not used by the code visible in this file.
_HEADS_FIRST_LAYOUT = (
    _split_heads_first,
    _merge_heads_first,
    lambda x: (1, x, 1, 1),
)
_SEQ_FIRST_LAYOUT = (
    _split_seq_first,
    _merge_seq_first,
    lambda x: (1, 1, x, 1),
)

MEMORY_LAYOUTS = {
    "torch": _HEADS_FIRST_LAYOUT,
    "xformers": _SEQ_FIRST_LAYOUT,
    "math": _HEADS_FIRST_LAYOUT,
}
# ATTN_FUNCTION = {
# "torch": F.scaled_dot_product_attention,
# "xformers": memory_efficient_attention,
# }
def vanilla_attention(q, k, v, mask, scale=None):
    """Plain softmax attention, used as the "math" fallback backend.

    Works on both 3-D ``[(B*H), L, head_dim]`` and 4-D ``[B, H, L, head_dim]``
    inputs. The previous ``torch.bmm`` implementation only accepted 3-D tensors
    and therefore failed on the 4-D layout that ``MEMORY_LAYOUTS["math"]``
    produces.

    Args:
        q, k, v: query / key / value tensors sharing the layout above.
        mask: optional boolean tensor whose first axis is the batch and whose
            remaining axes flatten to the key length; ``False`` entries are
            masked out for every query and head.
        scale: value the raw scores are *divided* by; defaults to
            ``sqrt(head_dim)``. Note this is a divisor, not a multiplier.

    Returns:
        Attention output with the same layout as ``q``.
    """
    if scale is None:
        scale = math.sqrt(q.size(-1))
    # matmul broadcasts over any number of leading batch axes, unlike bmm
    scores = torch.matmul(q, k.transpose(-1, -2)) / scale

    if mask is not None:
        key_mask = mask.reshape(mask.shape[0], -1)  # (B, L_k)
        if scores.dim() == 4:
            # broadcast over the head and query axes
            key_mask = key_mask[:, None, None, :]
        else:
            # heads are folded into the batch axis (batch-major): repeat per head
            heads = scores.shape[0] // key_mask.shape[0]
            key_mask = key_mask.repeat_interleave(heads, dim=0)[:, None, :]
        max_neg_value = -torch.finfo(scores.dtype).max
        scores = scores.masked_fill(~key_mask, max_neg_value)

    p_attn = F.softmax(scores, dim=-1)
    return torch.matmul(p_attn, v)
def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"):
    """
    Run multi-head attention with the backend selected by ``mode``.

    q, k, v: [B, L, D] (anything ``MEMORY_LAYOUTS[mode][0]`` can reshape into heads)
    head_dim: size of each attention head
    mask: forwarded to the backend unchanged
    scale: forwarded to the "xformers" and "math" backends; must be None for "torch"
    mode: key of ``MEMORY_LAYOUTS`` ("torch", "xformers" or "math")

    Returns the attention output converted back by ``MEMORY_LAYOUTS[mode][1]``.

    Raises:
        ImportError: ``mode == "xformers"`` but xformers could not be imported.
    """
    if mode == "xformers" and memory_efficient_attention is None:
        # previously this surfaced as "'NoneType' object is not callable"
        raise ImportError('attention mode "xformers" requires the xformers package, which could not be imported')

    to_backend_layout, from_backend_layout = MEMORY_LAYOUTS[mode][:2]
    q = to_backend_layout(q, head_dim)
    # k and v follow q's dtype/device
    k = to_backend_layout(k, head_dim).to(q)
    v = to_backend_layout(v, head_dim).to(q)

    if mode == "torch":
        assert scale is None
        out = F.scaled_dot_product_attention(q, k, v, mask)  # , scale=scale)
    elif mode == "xformers":
        out = memory_efficient_attention(q, k, v, mask, scale=scale)
    else:
        out = vanilla_attention(q, k, v, mask, scale=scale)

    return from_backend_layout(out)
# DismantledBlock in mmdit.py
class SingleDiTBlock(nn.Module):
    """
    One stream of a DiT block with gated adaptive layer norm (adaLN) conditioning.

    The block is split into ``pre_attention`` (modulate + q/k/v projection) and
    ``post_attention`` (gated attention residual + gated MLP residual) so the
    attention itself can be computed by the caller. A ``pre_only`` block has no
    second norm, no MLP and no output projection.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        qkv_bias: bool = False,
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        qk_norm: Optional[str] = None,
        **block_kwargs,
    ):
        super().__init__()
        assert attn_mode in MEMORY_LAYOUTS
        self.attn_mode = attn_mode

        def build_norm():
            # parameter-free pre-norm; `rmsnorm` here is the bool argument
            if rmsnorm:
                return RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.norm1 = build_norm()
        self.attn = AttentionLinears(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            pre_only=pre_only,
            qk_norm=qk_norm,
        )

        if not pre_only:
            self.norm2 = build_norm()
            mlp_hidden_dim = int(hidden_size * mlp_ratio)
            if swiglu:
                self.mlp = SwiGLUFeedForward(
                    dim=hidden_size,
                    hidden_dim=mlp_hidden_dim,
                    multiple_of=256,
                )
            else:
                self.mlp = MLP(
                    in_features=hidden_size,
                    hidden_features=mlp_hidden_dim,
                    act_layer=lambda: nn.GELU(approximate="tanh"),
                )

        self.scale_mod_only = scale_mod_only
        # number of hidden_size-wide modulation chunks produced from the conditioning
        if scale_mod_only:
            n_mods = 1 if pre_only else 4
        else:
            n_mods = 2 if pre_only else 6
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size))
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        """Modulate ``x`` with conditioning ``c`` and project it to q/k/v.

        Returns:
            ``(qkv, intermediates)`` where ``intermediates`` is
            ``(x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`` to be passed to
            ``post_attention``, or ``None`` for a ``pre_only`` block.
        """
        mods = self.adaLN_modulation(c)

        if self.pre_only:
            if self.scale_mod_only:
                shift_msa, scale_msa = None, mods
            else:
                shift_msa, scale_msa = mods.chunk(2, dim=-1)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, None

        if self.scale_mod_only:
            shift_msa = shift_mlp = None
            scale_msa, gate_msa, scale_mlp, gate_mlp = mods.chunk(4, dim=-1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mods.chunk(6, dim=-1)

        qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
        intermediates = (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        return qkv, intermediates

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        """Add the gated attention output and the gated MLP output to ``x``."""
        assert not self.pre_only
        attn_out = self.attn.post_attention(attn)
        x = x + gate_msa.unsqueeze(1) * attn_out
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + gate_mlp.unsqueeze(1) * mlp_out
        return x
# JointBlock + block_mixing in mmdit.py
class MMDiTBlock(nn.Module):
    """Joint block: a context stream and an x stream sharing one attention call.

    Both streams are ``SingleDiTBlock`` instances built from the same arguments;
    only the context stream honours ``pre_only``. The required ``pre_only``
    keyword is removed from ``kwargs`` before they are forwarded.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        pre_only = kwargs.pop("pre_only")
        self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
        self.x_block = SingleDiTBlock(*args, pre_only=False, **kwargs)
        self.head_dim = self.x_block.attn.head_dim
        self.mode = self.x_block.attn_mode
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        """Use activation checkpointing for this block while training."""
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        """Turn activation checkpointing off again.

        MMDiT.disable_gradient_checkpointing calls this on every block; without
        it that call raised AttributeError.
        """
        self.gradient_checkpointing = False

    def _forward(self, context, x, c):
        """Run both streams through one joint attention over [context; x].

        Returns ``(context, x)``; ``context`` is ``None`` when the context
        stream is ``pre_only``.
        """
        ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c)
        x_qkv, x_intermediate = self.x_block.pre_attention(x, c)

        # concatenate along the sequence axis, context tokens first
        ctx_len = ctx_qkv[0].size(1)
        q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1)
        k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1)
        v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1)

        attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode)
        ctx_attn_out = attn[:, :ctx_len]
        x_attn_out = attn[:, ctx_len:]

        x = self.x_block.post_attention(x_attn_out, *x_intermediate)
        if not self.context_block.pre_only:
            context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate)
        else:
            context = None
        return context, x

    def forward(self, *args, **kwargs):
        """Call ``_forward``, wrapped in ``checkpoint`` when training with checkpointing on."""
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)
class MMDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    The hidden size is fixed to ``64 * depth`` and the number of heads to
    ``depth``. The last joint block is built with ``pre_only=True``.

    Block swapping (keeping only part of ``joint_blocks`` on the accelerator) is
    opt-in via ``enable_block_swap``; until then ``blocks_to_swap`` is ``None``
    and ``forward`` runs all blocks in place.
    """

    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 28,
        # hidden_size: Optional[int] = None,
        # num_heads: Optional[int] = None,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_in_features: Optional[int] = None,
        context_embedder_out_features: Optional[int] = None,
        use_checkpoint: bool = False,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches=None,
        qk_norm: Optional[str] = None,
        qkv_bias: bool = True,
        model_type: str = "sd3m",
    ):
        super().__init__()
        self._model_type = model_type
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = default(out_channels, default_out_channels)
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        self.pos_embed_max_size = pos_embed_max_size
        self.gradient_checkpointing = use_checkpoint

        # Block-swap state must exist even when enable_block_swap() is never
        # called: forward() and the swap helpers read these attributes.
        self.blocks_to_swap = None
        self.thread_pool = None

        # hidden_size = default(hidden_size, 64 * depth)
        # num_heads = default(num_heads, hidden_size // 64)

        # apply magic --> this defines a head_size of 64
        self.hidden_size = 64 * depth
        num_heads = depth
        self.num_heads = num_heads

        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=self.pos_embed_max_size is None,
        )
        self.t_embedder = TimestepEmbedding(self.hidden_size)

        self.y_embedder = None
        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = Embedder(adm_in_channels, self.hidden_size)

        if context_embedder_in_features is not None:
            self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features)
        else:
            self.context_embedder = nn.Identity()

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size))

        # num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.empty(1, num_patches, self.hidden_size),
            )
        else:
            self.pos_embed = None

        self.use_checkpoint = use_checkpoint
        self.joint_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    self.hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_mode=attn_mode,
                    qkv_bias=qkv_bias,
                    pre_only=i == depth - 1,
                    rmsnorm=rmsnorm,
                    scale_mod_only=scale_mod_only,
                    swiglu=swiglu,
                    qk_norm=qk_norm,
                )
                for i in range(depth)
            ]
        )
        for block in self.joint_blocks:
            block.gradient_checkpointing = use_checkpoint

        self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
        # self.initialize_weights()

    @property
    def model_type(self):
        """The ``model_type`` string given at construction."""
        return self._model_type

    @property
    def device(self):
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    def _set_gradient_checkpointing(self, enabled: bool):
        """Set the checkpointing flag on the model and on every joint block."""
        self.gradient_checkpointing = enabled
        for block in self.joint_blocks:
            # set the flag directly so this does not rely on the block exposing
            # matching enable/disable methods
            block.gradient_checkpointing = enabled

    def enable_gradient_checkpointing(self):
        """Turn activation checkpointing on for all joint blocks."""
        self._set_gradient_checkpointing(True)

    def disable_gradient_checkpointing(self):
        """Turn activation checkpointing off for all joint blocks."""
        self._set_gradient_checkpointing(False)

    def initialize_weights(self):
        """Initialize weights from scratch (not called by ``__init__``)."""
        # TODO: Init context_embedder?
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding
        if self.pos_embed is not None:
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                int(self.pos_embed.shape[-2] ** 0.5),
                scaling_factor=self.pos_embed_scaling_factor,
            )
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        if getattr(self, "y_embedder", None) is not None:
            nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.joint_blocks:
            nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def cropped_pos_embed(self, h, w, device=None):
        """Return the positional embedding for an input of pixel size ``h`` x ``w``.

        Without a stored ``pos_embed`` buffer the embedding is computed on the
        fly; otherwise a centered ``h`` x ``w`` (in patches) window is cropped
        from the ``pos_embed_max_size`` square grid.
        """
        p = self.x_embedder.patch_size
        # patched size
        h = (h + 1) // p
        w = (w + 1) // p
        if self.pos_embed is None:
            return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
        assert self.pos_embed_max_size is not None
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)
        top = (self.pos_embed_max_size - h) // 2
        left = (self.pos_embed_max_size - w) // 2
        spatial_pos_embed = self.pos_embed.reshape(
            1,
            self.pos_embed_max_size,
            self.pos_embed_max_size,
            self.pos_embed.shape[-1],
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def enable_block_swap(self, num_blocks: int):
        """Enable swapping of ``num_blocks`` joint blocks during ``forward``."""
        self.blocks_to_swap = num_blocks

        n = 1  # async block swap. 1 is enough
        self.thread_pool = ThreadPoolExecutor(max_workers=n)

    def move_to_device_except_swap_blocks(self, device: torch.device):
        """Move everything to ``device``, leaving ``joint_blocks`` behind when swapping is on."""
        # assume model is on cpu
        if self.blocks_to_swap:
            # temporarily detach the blocks so that self.to() skips them
            save_blocks = self.joint_blocks
            self.joint_blocks = None

        self.to(device)

        if self.blocks_to_swap:
            self.joint_blocks = save_blocks

    def prepare_block_swap_before_forward(self):
        """Place blocks for the next forward pass; a no-op when swapping is off."""
        # make: first n blocks are on cuda, and last n blocks are on cpu
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            # raise ValueError("Block swap is not enabled.")
            return
        num_blocks = len(self.joint_blocks)
        for i in range(num_blocks - self.blocks_to_swap):
            self.joint_blocks[i].to(self.device)
        for i in range(num_blocks - self.blocks_to_swap, num_blocks):
            self.joint_blocks[i].to("cpu")
        clean_memory_on_device(self.device)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, D) tensor; embedded by ``y_embedder`` and added to the timestep embedding
        context: optional tensor passed through ``context_embedder`` and fed to
            the context stream of every joint block

        Returns a (N, out_channels, H, W) tensor (the un-patchified output cropped to H x W).
        """
        B, C, H, W = x.shape
        x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device).to(dtype=x.dtype)
        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        if y is not None and self.y_embedder is not None:
            y = self.y_embedder(y)  # (N, D)
            c = c + y  # (N, D)

        if context is not None:
            context = self.context_embedder(context)

        if self.register_length > 0:
            # prepend the learned register tokens to the context sequence
            context = torch.cat(
                (
                    einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]),
                    default(context, torch.Tensor([]).type_as(x)),
                ),
                1,
            )

        if not self.blocks_to_swap:
            for block in self.joint_blocks:
                context, x = block(context, x, c)
        else:
            futures = {}

            def submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda):
                def move_blocks(bidx_to_cpu, block_to_cpu, bidx_to_cuda, block_to_cuda):
                    # print(f"Moving {bidx_to_cpu} to cpu.")
                    block_to_cpu.to("cpu", non_blocking=True)
                    torch.cuda.empty_cache()
                    # print(f"Moving {bidx_to_cuda} to cuda.")
                    block_to_cuda.to(self.device, non_blocking=True)
                    torch.cuda.synchronize()
                    # print(f"Block move done. {bidx_to_cpu} to cpu, {bidx_to_cuda} to cuda.")
                    return block_idx_to_cpu, block_idx_to_cuda

                block_to_cpu = self.joint_blocks[block_idx_to_cpu]
                block_to_cuda = self.joint_blocks[block_idx_to_cuda]
                # print(f"Submit move blocks. {block_idx_to_cpu} to cpu, {block_idx_to_cuda} to cuda.")
                return self.thread_pool.submit(move_blocks, block_idx_to_cpu, block_to_cpu, block_idx_to_cuda, block_to_cuda)

            def wait_for_blocks_move(block_idx, ftrs):
                if block_idx not in ftrs:
                    return
                # print(f"Waiting for move blocks: {block_idx}")
                # start_time = time.perf_counter()
                ftr = ftrs.pop(block_idx)
                ftr.result()
                # torch.cuda.synchronize()
                # print(f"Move blocks took {time.perf_counter() - start_time:.2f} seconds")

            for block_idx, block in enumerate(self.joint_blocks):
                # make sure a pending transfer of this block has finished
                wait_for_blocks_move(block_idx, futures)

                context, x = block(context, x, c)

                if block_idx < self.blocks_to_swap:
                    # this block is done: swap it out and bring in its partner
                    # from the tail of the block list
                    block_idx_to_cpu = block_idx
                    block_idx_to_cuda = len(self.joint_blocks) - self.blocks_to_swap + block_idx
                    future = submit_move_blocks(block_idx_to_cpu, block_idx_to_cuda)
                    futures[block_idx_to_cuda] = future

        x = self.final_layer(x, c, H, W)  # Our final layer combined UnPatchify
        return x[:, :, :H, :W]
def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
    """Instantiate an ``MMDiT`` from ``params``.

    ``input_size`` is left as ``None``, ``in_channels`` is fixed to 16 and
    ``mlp_ratio`` to 4; the remaining arguments come from ``params``.
    """
    fixed_settings = dict(
        input_size=None,
        in_channels=16,
        mlp_ratio=4,
    )
    from_params = dict(
        pos_embed_max_size=params.pos_embed_max_size,
        patch_size=params.patch_size,
        adm_in_channels=params.adm_in_channels,
        context_embedder_in_features=params.context_embedder_in_features,
        context_embedder_out_features=params.context_embedder_out_features,
        depth=params.depth,
        qk_norm=params.qk_norm,
        num_patches=params.num_patches,
        model_type=params.model_type,
    )
    return MMDiT(attn_mode=attn_mode, **fixed_settings, **from_params)
# endregion
# region VAE
# Latent normalization constants used by SDVAE:
#   process_in(latent)  = (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR
#   process_out(latent) = (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR
VAE_SCALE_FACTOR = 1.5305
VAE_SHIFT_FACTOR = 0.0609
def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
    """Create the affine ``GroupNorm`` (eps=1e-6) used throughout the VAE."""
    norm = torch.nn.GroupNorm(
        num_groups=num_groups,
        num_channels=in_channels,
        eps=1e-6,
        affine=True,
        dtype=dtype,
        device=device,
    )
    return norm
class ResnetBlock(torch.nn.Module):
    """Pre-activation residual block: (norm -> SiLU -> 3x3 conv) twice, plus a skip.

    When the channel count changes, the skip path goes through a 1x1 conv
    (``nin_shortcut``); otherwise the input is added unchanged.
    """

    def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        factory = dict(dtype=dtype, device=device)
        self.norm1 = Normalize(in_channels, **factory)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, **factory)
        self.norm2 = Normalize(out_channels, **factory)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, **factory)

        self.nin_shortcut = None
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, **factory)

        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        """Return ``skip(x) + residual(x)``."""
        residual = self.conv1(self.swish(self.norm1(x)))
        residual = self.conv2(self.swish(self.norm2(residual)))

        skip = x
        if self.in_channels != self.out_channels:
            skip = self.nin_shortcut(x)
        return skip + residual
class AttnBlock(torch.nn.Module):
    """Single-head self-attention over all spatial positions, with a residual connection."""

    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.norm = Normalize(in_channels, dtype=dtype, device=device)

        def pointwise_conv():
            # 1x1 convolution that keeps the channel count
            return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)

        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        """Attend over the ``h * w`` positions of ``x`` and add the result to ``x``."""
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)
        b, c, h, w = q.shape

        def to_sequence(feature_map):
            # spatial grid -> one "head" holding a sequence of h*w tokens
            return einops.rearrange(feature_map, "b c h w -> b 1 (h w) c").contiguous()

        q, k, v = to_sequence(q), to_sequence(k), to_sequence(v)
        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
        attended = einops.rearrange(attended, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

        return x + self.proj_out(attended)
class Downsample(torch.nn.Module):
    """Downsample with a stride-2 3x3 conv after zero-padding the right and bottom edges."""

    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)

    def forward(self, x):
        """Pad one pixel on the right/bottom, then apply the strided conv."""
        # (left, right, top, bottom): the conv itself uses no padding
        padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class Upsample(torch.nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 conv."""

    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)

    def forward(self, x):
        """Double the spatial size of ``x`` and convolve.

        bfloat16 inputs are interpolated in float32 and cast back before the conv.
        """
        input_dtype = x.dtype
        via_float32 = input_dtype == torch.bfloat16
        if via_float32:
            x = x.to(torch.float32)

        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")

        if via_float32:
            x = x.to(input_dtype)
        return self.conv(x)
class VAEEncoder(torch.nn.Module):
    """Convolutional VAE encoder.

    Structure: input conv -> ``len(ch_mult)`` resolution levels of
    ``num_res_blocks`` ResnetBlocks (each level but the last followed by a
    Downsample) -> middle (ResnetBlock, AttnBlock, ResnetBlock) -> norm, SiLU
    and an output conv producing ``2 * z_channels`` channels.
    """

    def __init__(
        self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = torch.nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = torch.nn.ModuleList()
            attn = torch.nn.ModuleList()  # stays empty; kept so the module tree is unchanged
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            down = torch.nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, dtype=dtype, device=device)
            self.down.append(down)
        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        # end
        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        """Encode ``x`` and return the raw ``2 * z_channels``-channel output of the final conv."""
        # downsampling
        # Only the most recent activation is ever consumed, so keep a single
        # running tensor instead of accumulating every intermediate in a list
        # (which kept them all alive until the end of the call).
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)
        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)
        # end
        h = self.norm_out(h)
        h = self.swish(h)
        h = self.conv_out(h)
        return h
class VAEDecoder(torch.nn.Module):
    """Convolutional VAE decoder.

    Structure: input conv -> middle (ResnetBlock, AttnBlock, ResnetBlock) ->
    ``len(ch_mult)`` levels of ``num_res_blocks + 1`` ResnetBlocks, processed
    from the highest level index down (each level but level 0 followed by an
    Upsample) -> norm, SiLU and an output conv with ``out_ch`` channels.

    ``resolution`` is accepted but has no effect on the constructed modules.
    """

    def __init__(
        self,
        ch=128,
        out_ch=3,
        ch_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        resolution=256,
        z_channels=16,
        dtype=torch.float32,
        device=None,
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        factory = dict(dtype=dtype, device=device)

        width = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, width, kernel_size=3, stride=1, padding=1, **factory)

        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=width, out_channels=width, **factory)
        self.mid.attn_1 = AttnBlock(width, **factory)
        self.mid.block_2 = ResnetBlock(in_channels=width, out_channels=width, **factory)

        # upsampling
        self.up = torch.nn.ModuleList()
        for level_idx in reversed(range(self.num_resolutions)):
            level_width = ch * ch_mult[level_idx]
            res_blocks = torch.nn.ModuleList()
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(ResnetBlock(in_channels=width, out_channels=level_width, **factory))
                width = level_width

            level = torch.nn.Module()
            level.block = res_blocks
            if level_idx != 0:
                level.upsample = Upsample(width, **factory)
            self.up.insert(0, level)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(width, **factory)
        self.conv_out = torch.nn.Conv2d(width, out_ch, kernel_size=3, stride=1, padding=1, **factory)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, z):
        """Decode latent ``z`` and return the output of the final conv."""
        # z to block_in
        h = self.conv_in(z)

        # middle
        for stage in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = stage(h)

        # upsampling
        for level_idx in reversed(range(self.num_resolutions)):
            level = self.up[level_idx]
            for block_idx in range(self.num_res_blocks + 1):
                h = level.block[block_idx](h)
            if level_idx != 0:
                h = level.upsample(h)

        # end
        return self.conv_out(self.swish(self.norm_out(h)))
class SDVAE(torch.nn.Module):
    """VAE wrapper bundling a ``VAEEncoder`` and a ``VAEDecoder`` with default settings."""

    def __init__(self, dtype=torch.float32, device=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.encoder = VAEEncoder(**factory)
        self.decoder = VAEDecoder(**factory)

    @property
    def device(self):
        """Device of the first parameter."""
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        first_param = next(self.parameters())
        return first_param.dtype

    # @torch.autocast("cuda", dtype=torch.float16)
    def decode(self, latent):
        """Run the decoder on ``latent``."""
        return self.decoder(latent)

    # @torch.autocast("cuda", dtype=torch.float16)
    def encode(self, image):
        """Encode ``image`` and draw one sample from the predicted Gaussian.

        The encoder output is split along the channel axis into mean and
        log-variance; the log-variance is clamped to [-30, 20] before sampling.
        """
        mean, logvar = torch.chunk(self.encoder(image), 2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(mean)
        return mean + std * noise

    @staticmethod
    def process_in(latent):
        """Shift then scale a latent: ``(latent - shift) * scale``."""
        centered = latent - VAE_SHIFT_FACTOR
        return centered * VAE_SCALE_FACTOR

    @staticmethod
    def process_out(latent):
        """Inverse of ``process_in``: ``latent / scale + shift``."""
        unscaled = latent / VAE_SCALE_FACTOR
        return unscaled + VAE_SHIFT_FACTOR
# endregion