Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-16 00:49:40 +00:00)
feat: use unified attention module, add wrapper for state dict compatibility
@@ -13,7 +13,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint as torch_checkpoint

-from library import custom_offloading_utils
+from library import custom_offloading_utils, attention
 from library.device_utils import clean_memory_on_device
@@ -123,24 +123,24 @@ def unsloth_checkpoint(function, *args):
     return UnslothOffloadedGradientCheckpointer.apply(function, *args)


-# Flash Attention support
-try:
-    from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
-    FLASH_ATTN_AVAILABLE = True
-except ImportError:
-    _flash_attn_func = None
-    FLASH_ATTN_AVAILABLE = False
+# # Flash Attention support
+# try:
+#     from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
+#     FLASH_ATTN_AVAILABLE = True
+# except ImportError:
+#     _flash_attn_func = None
+#     FLASH_ATTN_AVAILABLE = False


-def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
-    """Computes multi-head attention using Flash Attention.
+# def flash_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
+#     """Computes multi-head attention using Flash Attention.

-    Input format: (batch, seq_len, n_heads, head_dim)
-    Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
-    """
-    # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
-    out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
-    return rearrange(out, "b s h d -> b s (h d)")
+#     Input format: (batch, seq_len, n_heads, head_dim)
+#     Output format: (batch, seq_len, n_heads * head_dim) — matches torch_attention_op output.
+#     """
+#     # flash_attn_func expects (B, S, H, D) and returns (B, S, H, D)
+#     out = _flash_attn_func(q_B_S_H_D, k_B_S_H_D, v_B_S_H_D)
+#     return rearrange(out, "b s h d -> b s (h d)")


 from .utils import setup_logging
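The commented-out flash_attention_op above defines the contract the unified module has to honor: inputs shaped (batch, seq_len, n_heads, head_dim), output flattened to (batch, seq_len, n_heads * head_dim). For readers without flash-attn installed, here is a minimal sketch of an equivalent op built on PyTorch's scaled_dot_product_attention; sdpa_attention_op is a hypothetical name for illustration and is not code from this repository.

# Hypothetical sketch (not repository code): same layout contract as flash_attention_op.
import torch
import torch.nn.functional as F
from einops import rearrange

def sdpa_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
    # SDPA expects (B, H, S, D), so swap the sequence and head axes first.
    q, k, v = (rearrange(t, "b s h d -> b h s d") for t in (q_B_S_H_D, k_B_S_H_D, v_B_S_H_D))
    out = F.scaled_dot_product_attention(q, k, v)  # (B, H, S, D)
    # Flatten the heads to match the (B, S, H*D) output of the flash path.
    return rearrange(out, "b h s d -> b s (h d)")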
@@ -399,18 +399,23 @@ class Attention(nn.Module):

         return q, k, v

-    def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-        result = self.attn_op(q, k, v)  # [B, S, H, D]
-        return self.output_dropout(self.output_proj(result))
+    # def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    #     result = self.attn_op(q, k, v)  # [B, S, H, D]
+    #     return self.output_dropout(self.output_proj(result))

     def forward(
         self,
         x: torch.Tensor,
+        attn_params: attention.AttentionParams,
         context: Optional[torch.Tensor] = None,
         rope_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
-        return self.compute_attention(q, k, v)
+        # return self.compute_attention(q, k, v)
+        qkv = [q, k, v]
+        del q, k, v
+        result = attention.attention(qkv, attn_params=attn_params)
+        return self.output_dropout(self.output_proj(result))


 # Positional Embeddings
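Only two entry points of the new library.attention module appear in this diff: AttentionParams.create_attention_params(attn_mode, split_attn) and attention.attention(qkv, attn_params=...), which takes q/k/v as a list so the caller can drop its own references (the `del q, k, v` above) before the backend runs. The module's internals are not part of this commit; the sketch below is only an assumed minimal shape of such a dispatcher, not the repository's actual implementation.

# Assumed dispatcher sketch; the real library/attention.py may differ in names and modes.
from dataclasses import dataclass
from typing import List

import torch
import torch.nn.functional as F
from einops import rearrange


@dataclass
class AttentionParams:
    attn_mode: str = "torch"   # assumed mode strings; check library/attention.py for the real set
    split_attn: bool = False   # process the batch in chunks to lower peak memory

    @classmethod
    def create_attention_params(cls, attn_mode: str, split_attn: bool) -> "AttentionParams":
        return cls(attn_mode=attn_mode, split_attn=split_attn)


def attention(qkv: List[torch.Tensor], attn_params: AttentionParams) -> torch.Tensor:
    # qkv holds (B, S, H, D) tensors; emptying the list means the only remaining
    # references are local, so each tensor can be freed as soon as it is consumed.
    q, k, v = qkv
    qkv.clear()
    if attn_params.attn_mode == "torch":
        q, k, v = (rearrange(t, "b s h d -> b h s d") for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        return rearrange(out, "b h s d -> b s (h d)")
    raise NotImplementedError(f"attn_mode {attn_params.attn_mode!r} is not covered by this sketch")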
@@ -904,6 +909,7 @@ class Block(nn.Module):
         x_B_T_H_W_D: torch.Tensor,
         emb_B_T_D: torch.Tensor,
         crossattn_emb: torch.Tensor,
+        attn_params: attention.AttentionParams,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -954,6 +960,7 @@ class Block(nn.Module):
         result = rearrange(
             self.self_attn(
                 rearrange(normalized_x, "b t h w d -> b (t h w) d"),
+                attn_params,
                 None,
                 rope_emb=rope_emb_L_1_1_D,
             ),
@@ -967,6 +974,7 @@ class Block(nn.Module):
         result = rearrange(
             self.cross_attn(
                 rearrange(normalized_x, "b t h w d -> b (t h w) d"),
+                attn_params,
                 crossattn_emb,
                 rope_emb=rope_emb_L_1_1_D,
             ),
@@ -987,6 +995,7 @@ class Block(nn.Module):
         x_B_T_H_W_D: torch.Tensor,
         emb_B_T_D: torch.Tensor,
         crossattn_emb: torch.Tensor,
+        attn_params: attention.AttentionParams,
         rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
         adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
@@ -996,7 +1005,7 @@ class Block(nn.Module):
             # Unsloth: async non-blocking CPU RAM offload (fastest offload method)
             return unsloth_checkpoint(
                 self._forward,
-                x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
+                x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params,
                 rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
             )
         elif self.cpu_offload_checkpointing:
@@ -1012,7 +1021,7 @@ class Block(nn.Module):

             return torch_checkpoint(
                 create_custom_forward(self._forward),
-                x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
+                x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params,
                 rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
                 use_reentrant=False,
             )
@@ -1020,13 +1029,13 @@ class Block(nn.Module):
             # Standard gradient checkpointing (no offload)
             return torch_checkpoint(
                 self._forward,
-                x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
+                x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params,
                 rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
                 use_reentrant=False,
             )
         else:
             return self._forward(
-                x_B_T_H_W_D, emb_B_T_D, crossattn_emb,
+                x_B_T_H_W_D, emb_B_T_D, crossattn_emb, attn_params,
                 rope_emb_L_1_1_D, adaln_lora_B_T_3D, extra_per_block_pos_emb,
             )
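All three checkpointing paths above simply append attn_params to the positional arguments. That works because torch.utils.checkpoint forwards extra arguments to the wrapped function unchanged, and non-tensor arguments are supported with use_reentrant=False. A minimal standalone check (hypothetical names, not repository code):

# Demonstrates that checkpoint() passes a non-tensor argument through unchanged.
import torch
from torch.utils.checkpoint import checkpoint


def blocklike(x: torch.Tensor, params: str) -> torch.Tensor:
    # `params` stands in for AttentionParams: a plain Python object, not a tensor.
    assert isinstance(params, str)
    return x * 2


x = torch.randn(3, requires_grad=True)
y = checkpoint(blocklike, x, "torch", use_reentrant=False)
y.sum().backward()  # gradients flow through the tensor arg; the string arg is passed through as-is
print(x.grad)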
@@ -1069,6 +1078,8 @@ class MiniTrainDIT(nn.Module):
         extra_t_extrapolation_ratio: float = 1.0,
         rope_enable_fps_modulation: bool = True,
         use_llm_adapter: bool = False,
+        attn_mode: str = "torch",
+        split_attn: bool = False,
     ) -> None:
         super().__init__()
         self.max_img_h = max_img_h
@@ -1097,6 +1108,9 @@ class MiniTrainDIT(nn.Module):
         self.rope_enable_fps_modulation = rope_enable_fps_modulation
         self.use_llm_adapter = use_llm_adapter

+        self.attn_mode = attn_mode
+        self.split_attn = split_attn
+
         # Block swap support
         self.blocks_to_swap = None
         self.offloader: Optional[custom_offloading_utils.ModelOffloader] = None
@@ -1170,17 +1184,17 @@ class MiniTrainDIT(nn.Module):
         return next(self.parameters()).device


-    def set_flash_attn(self, use_flash_attn: bool):
-        """Toggle flash attention for all DiT blocks (self-attn + cross-attn).
+    # def set_flash_attn(self, use_flash_attn: bool):
+    #     """Toggle flash attention for all DiT blocks (self-attn + cross-attn).

-        LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
-        """
-        if use_flash_attn and not FLASH_ATTN_AVAILABLE:
-            raise ImportError("flash_attn package is required for --flash_attn but is not installed")
-        attn_op = flash_attention_op if use_flash_attn else torch_attention_op
-        for block in self.blocks:
-            block.self_attn.attn_op = attn_op
-            block.cross_attn.attn_op = attn_op
+    #     LLM Adapter attention is NOT affected (it uses attention masks incompatible with flash_attn).
+    #     """
+    #     if use_flash_attn and not FLASH_ATTN_AVAILABLE:
+    #         raise ImportError("flash_attn package is required for --flash_attn but is not installed")
+    #     attn_op = flash_attention_op if use_flash_attn else torch_attention_op
+    #     for block in self.blocks:
+    #         block.self_attn.attn_op = attn_op
+    #         block.cross_attn.attn_op = attn_op

     def build_patch_embed(self) -> None:
         in_channels = self.in_channels + 1 if self.concat_padding_mask else self.in_channels
@@ -1337,6 +1351,8 @@ class MiniTrainDIT(nn.Module):
             "extra_per_block_pos_emb": extra_pos_emb,
         }

+        attn_params = attention.AttentionParams.create_attention_params(self.attn_mode, self.split_attn)
+
         for block_idx, block in enumerate(self.blocks):
             if self.blocks_to_swap:
                 self.offloader.wait_for_block(block_idx)
@@ -1345,6 +1361,7 @@ class MiniTrainDIT(nn.Module):
                 x_B_T_H_W_D,
                 t_embedding_B_T_D,
                 crossattn_emb,
+                attn_params,
                 **block_kwargs,
             )
@@ -1563,6 +1580,42 @@ class LLMAdapter(nn.Module):
         return self.norm(self.out_proj(x))


+class Anima(nn.Module):
+    """
+    Wrapper class for the MiniTrainDIT and LLM Adapter.
+    """
+    LATENT_CHANNELS = 16
+
+    def __init__(self, dit_config: dict):
+        super().__init__()
+        self.net = MiniTrainDIT(**dit_config)
+
+    @property
+    def device(self):
+        return self.net.device
+
+    @property
+    def dtype(self):
+        return self.net.dtype
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        context: torch.Tensor,
+        fps: Optional[torch.Tensor] = None,
+        padding_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.net(x, timesteps, context, fps=fps, padding_mask=padding_mask, **kwargs)
+
+    def preprocess_text_embeds(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None):
+        if target_input_ids is not None:
+            return self.net.llm_adapter(source_hidden_states, target_input_ids, target_attention_mask=target_attention_mask,
+                                        source_attention_mask=source_attention_mask)
+        else:
+            return source_hidden_states
+
+
 # VAE Wrapper

 # VAE normalization constants
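The "wrapper for state dict compatibility" in the commit title refers to the nesting above: because the DiT lives under self.net, every parameter key in Anima.state_dict() carries a "net." prefix, so checkpoints whose keys are stored as "net.*" (presumably the original checkpoint layout) load without remapping. A small self-contained illustration of that prefixing, using stand-in module names rather than the real classes:

# Illustration only: Inner/Wrapper are stand-ins for MiniTrainDIT/Anima.
import torch.nn as nn


class Inner(nn.Module):      # stand-in for MiniTrainDIT
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


class Wrapper(nn.Module):    # stand-in for Anima
    def __init__(self):
        super().__init__()
        self.net = Inner()


print(list(Inner().state_dict().keys()))    # ['proj.weight', 'proj.bias']
print(list(Wrapper().state_dict().keys()))  # ['net.proj.weight', 'net.proj.bias']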