feat: use unified attention module, add wrapper for state dict compatibility

Author: kohya-ss
Date:   2026-02-08 12:16:00 +09:00
Parent: 10445ff660
Commit: 44b8d79577


@@ -13,7 +13,7 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from library import custom_offloading_utils
from library import custom_offloading_utils, attention
from library.device_utils import clean_memory_on_device
@@ -123,24 +123,24 @@ def unsloth_checkpoint(function, *args):
return UnslothOffloadedGradientCheckpointer.apply(function, *args)
# Flash Attention support
try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
FLASH_ATTN_AVAILABLE = True
except ImportError:
_flash_attn_func = None
FLASH_ATTN_AVAILABLE = False
# # Flash Attention support
# try:
# from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
# FLASH_ATTN_AVAILABLE = True
# except ImportError:
# _flash_attn_func = None
# FLASH_ATTN_AVAILABLE = False
def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
"""Computes multi-head attention using Flash Attention.
# def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
# """Computes multi-head attention using Flash Attention.
Input format: (batch, seq_len, n_heads, head_dim)
Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
"""
# flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
return rearrange(out, "b s h d -> b s (h d)")
# Input format: (batch, seq_len, n_heads, head_dim)
# Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
# """
# # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
# out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
# return rearrange(out, "b s h d -> b s (h d)")
from .utils import setup_logging
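Note: the in-file Flash Attention helper above is retired in favor of the shared library.attention module. The sketch below illustrates the kind of unified dispatcher such a module provides; the AttentionParams fields, the "torch"/"flash" mode strings, and the fallback behavior are assumptions inferred from the call sites in this diff, not the actual contents of library/attention.py.

# Hypothetical sketch of a unified attention dispatcher (assumed, not the real
# library/attention.py). Inputs are (B, S, H, D); output is (B, S, H*D), matching
# the torch_attention_op contract referenced above. split_attn (batch-wise
# splitting to lower peak memory) is omitted here for brevity.
from dataclasses import dataclass
from typing import List

import torch
import torch.nn.functional as F
from einops import rearrange

try:
    from flash_attn.flash_attn_interface import flash_attn_func
except ImportError:
    flash_attn_func = None


@dataclass
class AttentionParams:
    attn_mode: str = "torch"   # "torch" -> SDPA; "flash" -> flash_attn (assumed names)
    split_attn: bool = False

    @classmethod
    def create_attention_params(cls, attn_mode: str, split_attn: bool) -> "AttentionParams":
        return cls(attn_mode=attn_mode, split_attn=split_attn)


def attention(qkv: List[torch.Tensor], attn_params: AttentionParams) -> torch.Tensor:
    q, k, v = qkv
    qkv.clear()  # drop the list's references; callers pass q/k/v via the list and del their locals
    if attn_params.attn_mode == "flash" and flash_attn_func is not None:
        out = flash_attn_func(q, k, v)                         # (B, S, H, D)
    else:
        q, k, v = (rearrange(t, "b s h d -> b h s d") for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)          # (B, H, S, D)
        out = rearrange(out, "b h s d -> b s h d")
    return rearrange(out, "b s h d -> b s (h d)")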
@@ -399,18 +399,23 @@ class Attention(nn.Module):
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
# def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
# result = self.attn_op(q, k, v) # [B, S, H, D]
# return self.output_dropout(self.output_proj(result))
def forward(
self,
x: torch.Tensor,
attn_params: attention.AttentionParams,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
# return self.compute_attention(q, k, v)
qkv = [q, k, v]
del q, k, v
result = attention.attention(qkv, attn_params=attn_params)
return self.output_dropout(self.output_proj(result))
# Positional Embeddings
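With the change above, the backend choice travels with an AttentionParams object instead of a per-module attn_op. A hypothetical call-site sketch follows; attn_layer stands for an already-constructed Attention module from this file, whose constructor is not shown in the diff.

attn_params = attention.AttentionParams.create_attention_params("torch", False)

x = torch.randn(2, 1024, 768)  # (B, S, model_dim), illustrative sizes
self_out = attn_layer(x, attn_params)                                     # self-attention (context=None)
cross_out = attn_layer(x, attn_params, context=torch.randn(2, 77, 768))   # cross-attention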
@@ -904,6 +909,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -954,6 +960,7 @@ class Block(nn.Module):
result = rearrange(
self.self_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
None,
rope_emb=rope_emb_L_1_1_D,
),
@@ -967,6 +974,7 @@ class Block(nn.Module):
result = rearrange(
self.cross_attn(
rearrange(normalized_x, "b t h w d -> b (t h w) d"),
attn_params,
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
),
@@ -987,6 +995,7 @@ class Block(nn.Module):
x_B_T_H_W_D: torch.Tensor,
emb_B_T_D: torch.Tensor,
crossattn_emb: torch.Tensor,
attn_params: attention.AttentionParams,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -996,7 +1005,7 @@ class Block(nn.Module):
# Unsloth: async non-blocking CPU RAM offload (fastest offload method)
return unsloth_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
)
elif self.cpu_offload_checkpointing:
@@ -1012,7 +1021,7 @@ class Block(nn.Module):
return torch_checkpoint(
create_custom_forward(self._forward),
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
use_reentrant=False,
)
@@ -1020,13 +1029,13 @@ class Block(nn.Module):
# Standard gradient checkpointing (no offload)
return torch_checkpoint(
self._forward,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
use_reentrant=False,
)
else:
return self._forward(
x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params,
rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
)
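Because attn_params is not a tensor, it simply rides along as a positional argument through all of the checkpoint paths above; torch's checkpoint passes non-tensor arguments through unchanged. A minimal standalone demonstration (assumed example, not from this repo):

import torch
from dataclasses import dataclass
from torch.utils.checkpoint import checkpoint as torch_checkpoint

@dataclass
class Params:
    mode: str = "torch"

def f(x, params):
    # only tensors participate in save/recompute; `params` is reused as-is
    return x * (2.0 if params.mode == "torch" else 0.5)

x = torch.randn(4, requires_grad=True)
y = torch_checkpoint(f, x, Params(), use_reentrant=False)
y.sum().backward()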
@@ -1069,6 +1078,8 @@ class MiniTrainDIT(nn.Module):
extra_t_extrapolation_ratio: float = 1.0,
rope_enable_fps_modulation: bool = True,
use_llm_adapter: bool = False,
attn_mode: str = "torch",
split_attn: bool = False,
) -> None:
super().__init__()
self.max_img_h = max_img_h
@@ -1097,6 +1108,9 @@ class MiniTrainDIT(nn.Module):
self.rope_enable_fps_modulation = rope_enable_fps_modulation
self.use_llm_adapter = use_llm_adapter
self.attn_mode = attn_mode
self.split_attn = split_attn
# Block swap support
self.blocks_to_swap = None
self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
@@ -1170,17 +1184,17 @@ class MiniTrainDIT(nn.Module):
return next(self.parameters()).device
def set_flash_attn(self, use_flash_attn: bool):
"""Toggle flash attention for all DiT blocks (self-attn + cross-attn).
# def set_flash_attn(self, use_flash_attn: bool):
# """Toggle flash attention for all DiT blocks (self-attn + cross-attn).
LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
"""
if use_flash_attn and not FLASH_ATTN_AVAILABLE:
raise ImportError("flash_attn package is required for --flash_attn but is not installed")
attn_op = flash_attention_op if use_flash_attn else torch_attention_op
for block in self.blocks:
block.self_attn.attn_op = attn_op
block.cross_attn.attn_op = attn_op
# LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
# """
# if use_flash_attn and not FLASH_ATTN_AVAILABLE:
# raise ImportError("flash_attn package is required for --flash_attn but is not installed")
# attn_op = flash_attention_op if use_flash_attn else torch_attention_op
# for block in self.blocks:
# block.self_attn.attn_op = attn_op
# block.cross_attn.attn_op = attn_op
def build_patch_embed(self) -> None:
in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
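The runtime set_flash_attn() toggle above is retired: the backend is now fixed at construction via attn_mode/split_attn and carried per call in AttentionParams. If a runtime switch is still wanted, it can be expressed as a small helper like the following (hypothetical; the "flash" mode string is an assumption):

def make_attn_params(net: "MiniTrainDIT", use_flash_attn: bool) -> "attention.AttentionParams":
    # Build a fresh params object per forward call instead of mutating each
    # block's attn_op as the old set_flash_attn() did.
    mode = "flash" if use_flash_attn else "torch"  # "flash" is an assumed mode name
    return attention.AttentionParams.create_attention_params(mode, net.split_attn)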
@@ -1337,6 +1351,8 @@ class MiniTrainDIT(nn.Module):
"extra_per_block_pos_emb": extra_pos_emb,
}
attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
for block_idx, block in enumerate(self.blocks):
if self.blocks_to_swap:
self.offloader.wait_for_block(block_idx)
@@ -1345,6 +1361,7 @@ class MiniTrainDIT(nn.Module):
x_B_T_H_W_D,
t_embedding_B_T_D,
crossattn_emb,
attn_params,
**block_kwargs,
)
@@ -1563,6 +1580,42 @@ class LLMAdapter(nn.Module):
return self.norm(self.out_proj(x))
class Anima(nn.Module):
    """
    Wrapper class for the MiniTrainDIT and LLM Adapter.
    """

    LATENT_CHANNELS = 16

    def __init__(self, dit_config: dict):
        super().__init__()
        self.net = MiniTrainDIT(**dit_config)

    @property
    def device(self):
        return self.net.device

    @property
    def dtype(self):
        return self.net.dtype

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        fps: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)

    def preprocess_text_embeds(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
        if target_input_ids is not None:
            return self.net.llm_adapter(
                source_hidden_states,
                target_input_ids,
                target_attention_mask=target_attention_mask,
                source_attention_mask=source_attention_mask,
            )
        else:
            return source_hidden_states
# VAE Wrapper
# VAE normalization constants
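On the "wrapper for state dict compatibility" part of this commit: Anima holds the DiT under self.net, so every parameter key in its state dict carries a net. prefix. A hypothetical remapping helper follows, in case raw MiniTrainDIT checkpoints need to be loaded through the wrapper; whether existing checkpoints actually lack the prefix is an assumption.

def remap_to_wrapper(raw_sd: dict) -> dict:
    # raw MiniTrainDIT keys -> Anima keys ("<param>" -> "net.<param>")
    return {f"net.{k}": v for k, v in raw_sd.items()}

def remap_from_wrapper(wrapped_sd: dict) -> dict:
    # Anima keys -> raw MiniTrainDIT keys
    return {k.removeprefix("net."): v for k, v in wrapped_sd.items()}

# model = Anima(dit_config)
# model.load_state_dict(remap_to_wrapper(raw_state_dict))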