mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
feat: add multi backend attention and related update for HI2.1 models and scripts
This commit is contained in:
@@ -126,7 +126,8 @@ accelerate launch --num_cpu_threads_per_process 1 hunyuan_image_train_network.py
|
|||||||
--learning_rate=1e-4 \
|
--learning_rate=1e-4 \
|
||||||
--optimizer_type="AdamW8bit" \
|
--optimizer_type="AdamW8bit" \
|
||||||
--lr_scheduler="constant" \
|
--lr_scheduler="constant" \
|
||||||
--sdpa \
|
--attn_mode="torch" \
|
||||||
|
--split_attn \
|
||||||
--max_train_epochs=10 \
|
--max_train_epochs=10 \
|
||||||
--save_every_n_epochs=1 \
|
--save_every_n_epochs=1 \
|
||||||
--mixed_precision="bf16" \
|
--mixed_precision="bf16" \
|
||||||
@@ -175,6 +176,10 @@ The script adds HunyuanImage-2.1 specific arguments. For common arguments (like
|
|||||||
|
|
||||||
#### Memory/Speed Related
|
#### Memory/Speed Related
|
||||||
|
|
||||||
|
* `--attn_mode=<choice>`
|
||||||
|
- Specifies the attention implementation to use. Options are `torch`, `xformers`, `flash`, and `sageattn`. Default is `torch` (PyTorch scaled dot product attention). Every library other than `torch` must be installed separately. `sageattn` does not support training (inference only). When using `xformers` with a batch size greater than 1, also specify `--split_attn`.
|
||||||
|
* `--split_attn`
|
||||||
|
- Splits the batch during attention computation to process one item at a time, reducing VRAM usage by avoiding attention mask computation. Can improve speed when using `torch`. Required when using `xformers` with batch size greater than 1.
|
||||||
* `--fp8_scaled`
|
* `--fp8_scaled`
|
||||||
- Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option.
|
- Enables training the DiT model in scaled FP8 format. This can significantly reduce VRAM usage (can run with as little as 8GB VRAM when combined with `--blocks_to_swap`), but the training results may vary. This is a newer alternative to the unsupported `--fp8_base` option.
|
||||||
* `--fp8_vl`
|
* `--fp8_vl`
|
||||||
@@ -429,6 +434,7 @@ python hunyuan_image_minimal_inference.py \
|
|||||||
--vae "<path to hunyuan_image_2.1_vae_fp16.safetensors>" \
|
--vae "<path to hunyuan_image_2.1_vae_fp16.safetensors>" \
|
||||||
--lora_weight "<path to your trained LoRA>" \
|
--lora_weight "<path to your trained LoRA>" \
|
||||||
--lora_multiplier 1.0 \
|
--lora_multiplier 1.0 \
|
||||||
|
--attn_mode "torch" \
|
||||||
--prompt "A cute cartoon penguin in a snowy landscape" \
|
--prompt "A cute cartoon penguin in a snowy landscape" \
|
||||||
--image_size 2048 2048 \
|
--image_size 2048 2048 \
|
||||||
--infer_steps 50 \
|
--infer_steps 50 \
|
||||||
@@ -445,6 +451,8 @@ python hunyuan_image_minimal_inference.py \
|
|||||||
- `--guidance_scale`: CFG scale (default: 3.5)
|
- `--guidance_scale`: CFG scale (default: 3.5)
|
||||||
- `--flow_shift`: Flow matching shift parameter (default: 5.0)
|
- `--flow_shift`: Flow matching shift parameter (default: 5.0)
|
||||||
|
|
||||||
|
`--split_attn` is not supported (since inference is done one at a time).
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>日本語</summary>
|
<summary>日本語</summary>
|
||||||
|
|
||||||
@@ -457,6 +465,8 @@ python hunyuan_image_minimal_inference.py \
|
|||||||
- `--guidance_scale`: CFGスケール(推奨: 3.5)
|
- `--guidance_scale`: CFGスケール(推奨: 3.5)
|
||||||
- `--flow_shift`: Flow Matchingシフトパラメータ(デフォルト: 5.0)
|
- `--flow_shift`: Flow Matchingシフトパラメータ(デフォルト: 5.0)
|
||||||
|
|
||||||
|
`--split_attn`はサポートされていません(1件ずつ推論するため)。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## 9. Related Tools / 関連ツール
|
## 9. Related Tools / 関連ツール
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ def parse_args() -> argparse.Namespace:
|
|||||||
"--attn_mode",
|
"--attn_mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="torch",
|
default="torch",
|
||||||
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3",
|
choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "sdpa" for backward compatibility
|
||||||
help="attention mode",
|
help="attention mode",
|
||||||
)
|
)
|
||||||
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
|
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
|
||||||
@@ -130,6 +130,9 @@ def parse_args() -> argparse.Namespace:
|
|||||||
if args.lycoris and not lycoris_available:
|
if args.lycoris and not lycoris_available:
|
||||||
raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS")
|
raise ValueError("install lycoris: https://github.com/KohakuBlueleaf/LyCORIS")
|
||||||
|
|
||||||
|
if args.attn_mode == "sdpa":
|
||||||
|
args.attn_mode = "torch" # backward compatibility
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
@@ -265,7 +268,7 @@ def load_dit_model(
|
|||||||
device,
|
device,
|
||||||
args.dit,
|
args.dit,
|
||||||
args.attn_mode,
|
args.attn_mode,
|
||||||
False,
|
True, # enable split_attn to trim masked tokens
|
||||||
loading_device,
|
loading_device,
|
||||||
loading_weight_dtype,
|
loading_weight_dtype,
|
||||||
args.fp8_scaled and not args.lycoris,
|
args.fp8_scaled and not args.lycoris,
|
||||||
|
|||||||
@@ -379,18 +379,19 @@ class HunyuanImageNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
|
|
||||||
loading_dtype = None if args.fp8_scaled else weight_dtype
|
loading_dtype = None if args.fp8_scaled else weight_dtype
|
||||||
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
|
loading_device = "cpu" if self.is_swapping_blocks else accelerator.device
|
||||||
split_attn = True
|
|
||||||
|
|
||||||
attn_mode = "torch"
|
attn_mode = "torch"
|
||||||
if args.xformers:
|
if args.xformers:
|
||||||
attn_mode = "xformers"
|
attn_mode = "xformers"
|
||||||
logger.info("xformers is enabled for attention")
|
if args.attn_mode is not None:
|
||||||
|
attn_mode = args.attn_mode
|
||||||
|
|
||||||
|
logger.info(f"Loading DiT model with attn_mode: {attn_mode}, split_attn: {args.split_attn}, fp8_scaled: {args.fp8_scaled}")
|
||||||
model = hunyuan_image_models.load_hunyuan_image_model(
|
model = hunyuan_image_models.load_hunyuan_image_model(
|
||||||
accelerator.device,
|
accelerator.device,
|
||||||
args.pretrained_model_name_or_path,
|
args.pretrained_model_name_or_path,
|
||||||
attn_mode,
|
attn_mode,
|
||||||
split_attn,
|
args.split_attn,
|
||||||
loading_device,
|
loading_device,
|
||||||
loading_dtype,
|
loading_dtype,
|
||||||
args.fp8_scaled,
|
args.fp8_scaled,
|
||||||
@@ -674,6 +675,19 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする",
|
help="Enable tiling for VAE decoding and encoding / VAEデコーディングとエンコーディングのタイルを有効にする",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--attn_mode",
|
||||||
|
choices=["torch", "xformers", "flash", "sageattn", "sdpa"], # "sdpa" is for backward compatibility
|
||||||
|
default=None,
|
||||||
|
help="Attention implementation to use. Default is None (torch). xformers requires --split_attn. sageattn does not support training (inference only). This option overrides --xformers or --sdpa."
|
||||||
|
" / 使用するAttentionの実装。デフォルトはNone(torch)です。xformersは--split_attnの指定が必要です。sageattnはトレーニングをサポートしていません(推論のみ)。このオプションは--xformersまたは--sdpaを上書きします。",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--split_attn",
|
||||||
|
action="store_true",
|
||||||
|
help="split attention computation to reduce memory usage / メモリ使用量を減らすためにattention時にバッチを分割する",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@@ -684,5 +698,8 @@ if __name__ == "__main__":
|
|||||||
train_util.verify_command_line_training_args(args)
|
train_util.verify_command_line_training_args(args)
|
||||||
args = train_util.read_config_from_file(args, parser)
|
args = train_util.read_config_from_file(args, parser)
|
||||||
|
|
||||||
|
if args.attn_mode == "sdpa":
|
||||||
|
args.attn_mode = "torch" # backward compatibility
|
||||||
|
|
||||||
trainer = HunyuanImageNetworkTrainer()
|
trainer = HunyuanImageNetworkTrainer()
|
||||||
trainer.train(args)
|
trainer.train(args)
|
||||||
|
|||||||
@@ -1,18 +1,88 @@
|
|||||||
|
# Unified attention function supporting various implementations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
import torch
|
import torch
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
try:
|
||||||
|
import flash_attn
|
||||||
|
from flash_attn.flash_attn_interface import _flash_attn_forward
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||||
|
from flash_attn.flash_attn_interface import flash_attn_func
|
||||||
|
except ImportError:
|
||||||
|
flash_attn = None
|
||||||
|
flash_attn_varlen_func = None
|
||||||
|
_flash_attn_forward = None
|
||||||
|
flash_attn_func = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sageattention import sageattn_varlen, sageattn
|
||||||
|
except ImportError:
|
||||||
|
sageattn_varlen = None
|
||||||
|
sageattn = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import xformers.ops as xops
|
import xformers.ops as xops
|
||||||
except ImportError:
|
except ImportError:
|
||||||
xops = None
|
xops = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AttentionParams:
|
||||||
|
attn_mode: Optional[str] = None
|
||||||
|
split_attn: bool = False
|
||||||
|
img_len: Optional[int] = None
|
||||||
|
attention_mask: Optional[torch.Tensor] = None
|
||||||
|
seqlens: Optional[torch.Tensor] = None
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None
|
||||||
|
max_seqlen: Optional[int] = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_attention_params(attn_mode: Optional[str], split_attn: bool) -> "AttentionParams":
|
||||||
|
return AttentionParams(attn_mode, split_attn)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_attention_params_from_mask(
|
||||||
|
attn_mode: Optional[str], split_attn: bool, img_len: Optional[int], attention_mask: Optional[torch.Tensor]
|
||||||
|
) -> "AttentionParams":
|
||||||
|
if attention_mask is None:
|
||||||
|
# No attention mask provided: assume all tokens are valid
|
||||||
|
return AttentionParams(attn_mode, split_attn, None, None, None, None, None)
|
||||||
|
else:
|
||||||
|
# Note: attention_mask is only for text tokens, not including image tokens
|
||||||
|
seqlens = attention_mask.sum(dim=1).to(torch.int32) + img_len # [B]
|
||||||
|
max_seqlen = attention_mask.shape[1] + img_len
|
||||||
|
|
||||||
|
if split_attn:
|
||||||
|
# cu_seqlens is not needed for split attention
|
||||||
|
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, None, max_seqlen)
|
||||||
|
|
||||||
|
# Convert attention mask to cumulative sequence lengths for flash attention
|
||||||
|
batch_size = attention_mask.shape[0]
|
||||||
|
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=attention_mask.device)
|
||||||
|
for i in range(batch_size):
|
||||||
|
cu_seqlens[2 * i + 1] = i * max_seqlen + seqlens[i] # end of valid tokens for query
|
||||||
|
cu_seqlens[2 * i + 2] = (i + 1) * max_seqlen # end of all tokens for query
|
||||||
|
|
||||||
|
# Expand attention mask to include image tokens
|
||||||
|
attention_mask = torch.nn.functional.pad(attention_mask, (img_len, 0), value=1) # [B, img_len + L]
|
||||||
|
|
||||||
|
if attn_mode == "xformers":
|
||||||
|
seqlens_list = seqlens.cpu().tolist()
|
||||||
|
attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
|
||||||
|
seqlens_list, seqlens_list, device=attention_mask.device
|
||||||
|
)
|
||||||
|
elif attn_mode == "torch":
|
||||||
|
attention_mask = attention_mask[:, None, None, :].to(torch.bool) # [B, 1, 1, img_len + L]
|
||||||
|
|
||||||
|
return AttentionParams(attn_mode, split_attn, img_len, attention_mask, seqlens, cu_seqlens, max_seqlen)
|
||||||
|
|
||||||
|
|
||||||
def attention(
|
def attention(
|
||||||
qkv_or_q: Union[torch.Tensor, list],
|
qkv_or_q: Union[torch.Tensor, list],
|
||||||
k: Optional[torch.Tensor] = None,
|
k: Optional[torch.Tensor] = None,
|
||||||
v: Optional[torch.Tensor] = None,
|
v: Optional[torch.Tensor] = None,
|
||||||
seq_lens: Optional[list[int]] = None,
|
attn_params: Optional[AttentionParams] = None,
|
||||||
attn_mode: str = "torch",
|
|
||||||
drop_rate: float = 0.0,
|
drop_rate: float = 0.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -25,8 +95,7 @@ def attention(
|
|||||||
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
|
qkv_or_q: Query tensor [B, L, H, D]. or list of such tensors.
|
||||||
k: Key tensor [B, L, H, D].
|
k: Key tensor [B, L, H, D].
|
||||||
v: Value tensor [B, L, H, D].
|
v: Value tensor [B, L, H, D].
|
||||||
seq_lens: Valid sequence length for each batch element.
|
attn_params: Attention parameters including mask and sequence lengths.
|
||||||
attn_mode: Attention implementation ("torch" or "sageattn").
|
|
||||||
drop_rate: Attention dropout rate.
|
drop_rate: Attention dropout rate.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -34,53 +103,158 @@ def attention(
|
|||||||
"""
|
"""
|
||||||
if isinstance(qkv_or_q, list):
|
if isinstance(qkv_or_q, list):
|
||||||
q, k, v = qkv_or_q
|
q, k, v = qkv_or_q
|
||||||
|
q: torch.Tensor = q
|
||||||
qkv_or_q.clear()
|
qkv_or_q.clear()
|
||||||
del qkv_or_q
|
del qkv_or_q
|
||||||
else:
|
else:
|
||||||
q = qkv_or_q
|
q: torch.Tensor = qkv_or_q
|
||||||
del qkv_or_q
|
del qkv_or_q
|
||||||
assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"
|
assert k is not None and v is not None, "k and v must be provided if qkv_or_q is a tensor"
|
||||||
if seq_lens is None:
|
if attn_params is None:
|
||||||
seq_lens = [q.shape[1]] * q.shape[0]
|
attn_params = AttentionParams.create_attention_params("torch", False)
|
||||||
|
|
||||||
|
# If split_attn is False, an attention mask is provided, and all sequence lengths are the same, we can trim the sequence
|
||||||
|
seqlen_trimmed = False
|
||||||
|
if not attn_params.split_attn and attn_params.attention_mask is not None and attn_params.seqlens is not None:
|
||||||
|
if torch.all(attn_params.seqlens == attn_params.seqlens[0]):
|
||||||
|
seqlen = attn_params.seqlens[0].item()
|
||||||
|
q = q[:, :seqlen]
|
||||||
|
k = k[:, :seqlen]
|
||||||
|
v = v[:, :seqlen]
|
||||||
|
max_seqlen = attn_params.max_seqlen
|
||||||
|
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, False) # do not in-place modify
|
||||||
|
attn_params.max_seqlen = max_seqlen # keep max_seqlen for padding
|
||||||
|
seqlen_trimmed = True
|
||||||
|
|
||||||
# Determine tensor layout based on attention implementation
|
# Determine tensor layout based on attention implementation
|
||||||
if attn_mode == "torch" or attn_mode == "sageattn":
|
if attn_params.attn_mode == "torch" or (
|
||||||
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA
|
attn_params.attn_mode == "sageattn" and (attn_params.split_attn or attn_params.cu_seqlens is None)
|
||||||
|
):
|
||||||
|
transpose_fn = lambda x: x.transpose(1, 2) # [B, H, L, D] for SDPA and sageattn with fixed length
|
||||||
|
# pad on sequence length dimension
|
||||||
|
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, pad_to - x.shape[-2]), value=0)
|
||||||
else:
|
else:
|
||||||
transpose_fn = lambda x: x # [B, L, H, D] for other implementations
|
transpose_fn = lambda x: x # [B, L, H, D] for other implementations
|
||||||
|
# pad on sequence length dimension
|
||||||
|
pad_fn = lambda x, pad_to: torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad_to - x.shape[-3]), value=0)
|
||||||
|
|
||||||
# Process each batch element with its valid sequence length
|
# Process each batch element with its valid sequence lengths
|
||||||
q_seq_len = q.shape[1]
|
if attn_params.split_attn:
|
||||||
q = [transpose_fn(q[i : i + 1, : seq_lens[i]]) for i in range(len(q))]
|
if attn_params.seqlens is None:
|
||||||
k = [transpose_fn(k[i : i + 1, : seq_lens[i]]) for i in range(len(k))]
|
# If no seqlens provided, assume all tokens are valid
|
||||||
v = [transpose_fn(v[i : i + 1, : seq_lens[i]]) for i in range(len(v))]
|
attn_params = AttentionParams.create_attention_params(attn_params.attn_mode, True) # do not in-place modify
|
||||||
|
attn_params.seqlens = torch.tensor([q.shape[1]] * q.shape[0], device=q.device)
|
||||||
|
attn_params.max_seqlen = q.shape[1]
|
||||||
|
q = [transpose_fn(q[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(q))]
|
||||||
|
k = [transpose_fn(k[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(k))]
|
||||||
|
v = [transpose_fn(v[i : i + 1, : attn_params.seqlens[i]]) for i in range(len(v))]
|
||||||
|
else:
|
||||||
|
q = transpose_fn(q)
|
||||||
|
k = transpose_fn(k)
|
||||||
|
v = transpose_fn(v)
|
||||||
|
|
||||||
if attn_mode == "torch":
|
if attn_params.attn_mode == "torch":
|
||||||
x = []
|
if attn_params.split_attn:
|
||||||
for i in range(len(q)):
|
x = []
|
||||||
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
|
for i in range(len(q)):
|
||||||
q[i] = None
|
x_i = torch.nn.functional.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate)
|
||||||
k[i] = None
|
q[i] = None
|
||||||
v[i] = None
|
k[i] = None
|
||||||
x.append(torch.nn.functional.pad(x_i, (0, 0, 0, q_seq_len - x_i.shape[2]), value=0)) # Pad to max seq len, B, H, L, D
|
v[i] = None
|
||||||
x = torch.cat(x, dim=0)
|
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
|
||||||
del q, k, v
|
x = torch.cat(x, dim=0)
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
elif attn_mode == "xformers":
|
else:
|
||||||
x = []
|
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_params.attention_mask, dropout_p=drop_rate)
|
||||||
for i in range(len(q)):
|
del q, k, v
|
||||||
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
|
|
||||||
q[i] = None
|
elif attn_params.attn_mode == "xformers":
|
||||||
k[i] = None
|
if attn_params.split_attn:
|
||||||
v[i] = None
|
x = []
|
||||||
x.append(torch.nn.functional.pad(x_i, (0, 0, 0, 0, 0, q_seq_len - x_i.shape[1]), value=0)) # B, L, H, D
|
for i in range(len(q)):
|
||||||
x = torch.cat(x, dim=0)
|
x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)
|
||||||
del q, k, v
|
q[i] = None
|
||||||
|
k[i] = None
|
||||||
|
v[i] = None
|
||||||
|
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
|
||||||
|
x = torch.cat(x, dim=0)
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
|
else:
|
||||||
|
x = xops.memory_efficient_attention(q, k, v, attn_bias=attn_params.attention_mask, p=drop_rate)
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
|
elif attn_params.attn_mode == "sageattn":
|
||||||
|
if attn_params.split_attn:
|
||||||
|
x = []
|
||||||
|
for i in range(len(q)):
|
||||||
|
# HND seems to cause an error
|
||||||
|
x_i = sageattn(q[i], k[i], v[i]) # B, H, L, D. No dropout support
|
||||||
|
q[i] = None
|
||||||
|
k[i] = None
|
||||||
|
v[i] = None
|
||||||
|
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, H, L, D
|
||||||
|
x = torch.cat(x, dim=0)
|
||||||
|
del q, k, v
|
||||||
|
elif attn_params.cu_seqlens is None: # all tokens are valid
|
||||||
|
x = sageattn(q, k, v) # B, L, H, D. No dropout support
|
||||||
|
del q, k, v
|
||||||
|
else:
|
||||||
|
# Reshape to [(bxs), a, d]
|
||||||
|
batch_size, seqlen = q.shape[0], q.shape[1]
|
||||||
|
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
|
||||||
|
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
|
||||||
|
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
|
||||||
|
|
||||||
|
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv. No dropout support
|
||||||
|
x = sageattn_varlen(
|
||||||
|
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen
|
||||||
|
)
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
|
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
|
||||||
|
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
|
||||||
|
|
||||||
|
elif attn_params.attn_mode == "flash":
|
||||||
|
if attn_params.split_attn:
|
||||||
|
x = []
|
||||||
|
for i in range(len(q)):
|
||||||
|
# HND seems to cause an error
|
||||||
|
x_i = flash_attn_func(q[i], k[i], v[i], drop_rate) # B, L, H, D
|
||||||
|
q[i] = None
|
||||||
|
k[i] = None
|
||||||
|
v[i] = None
|
||||||
|
x.append(pad_fn(x_i, attn_params.max_seqlen)) # B, L, H, D
|
||||||
|
x = torch.cat(x, dim=0)
|
||||||
|
del q, k, v
|
||||||
|
elif attn_params.cu_seqlens is None: # all tokens are valid
|
||||||
|
x = flash_attn_func(q, k, v, drop_rate) # B, L, H, D
|
||||||
|
del q, k, v
|
||||||
|
else:
|
||||||
|
# Reshape to [(bxs), a, d]
|
||||||
|
batch_size, seqlen = q.shape[0], q.shape[1]
|
||||||
|
q = q.view(q.shape[0] * q.shape[1], *q.shape[2:]) # [B*L, H, D]
|
||||||
|
k = k.view(k.shape[0] * k.shape[1], *k.shape[2:]) # [B*L, H, D]
|
||||||
|
v = v.view(v.shape[0] * v.shape[1], *v.shape[2:]) # [B*L, H, D]
|
||||||
|
|
||||||
|
# Assume cu_seqlens_q == cu_seqlens_kv and max_seqlen_q == max_seqlen_kv
|
||||||
|
x = flash_attn_varlen_func(
|
||||||
|
q, k, v, attn_params.cu_seqlens, attn_params.cu_seqlens, attn_params.max_seqlen, attn_params.max_seqlen, drop_rate
|
||||||
|
)
|
||||||
|
del q, k, v
|
||||||
|
|
||||||
|
# Reshape x with shape [(bxs), a, d] to [b, s, a, d]
|
||||||
|
x = x.view(batch_size, seqlen, x.shape[-2], x.shape[-1]) # B, L, H, D
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Currently only PyTorch SDPA and xformers are implemented
|
# Currently only PyTorch SDPA and xformers are implemented
|
||||||
raise ValueError(f"Unsupported attention mode: {attn_mode}")
|
raise ValueError(f"Unsupported attention mode: {attn_params.attn_mode}")
|
||||||
|
|
||||||
x = transpose_fn(x) # [B, L, H, D]
|
x = transpose_fn(x) # [B, L, H, D]
|
||||||
x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D]
|
x = x.reshape(x.shape[0], x.shape[1], -1) # [B, L, H*D]
|
||||||
|
|
||||||
|
if seqlen_trimmed:
|
||||||
|
x = torch.nn.functional.pad(x, (0, 0, 0, attn_params.max_seqlen - x.shape[1]), value=0) # pad back to max_seqlen
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import torch.nn as nn
|
|||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
|
|
||||||
from library import custom_offloading_utils
|
from library import custom_offloading_utils
|
||||||
|
from library.attention import AttentionParams
|
||||||
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
from library.fp8_optimization_utils import apply_fp8_monkey_patch
|
||||||
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
from library.lora_utils import load_safetensors_with_lora_and_fp8
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
@@ -50,7 +51,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
attn_mode: Attention implementation mode ("torch" or "sageattn").
|
attn_mode: Attention implementation mode ("torch" or "sageattn").
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, attn_mode: str = "torch"):
|
def __init__(self, attn_mode: str = "torch", split_attn: bool = False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Fixed architecture parameters for HunyuanImage-2.1
|
# Fixed architecture parameters for HunyuanImage-2.1
|
||||||
@@ -80,6 +81,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
qk_norm_type: str = "rms" # RMS normalization type
|
qk_norm_type: str = "rms" # RMS normalization type
|
||||||
|
|
||||||
self.attn_mode = attn_mode
|
self.attn_mode = attn_mode
|
||||||
|
self.split_attn = split_attn
|
||||||
|
|
||||||
# ByT5 character-level text encoder mapping
|
# ByT5 character-level text encoder mapping
|
||||||
self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False)
|
self.byt5_in = ByT5Mapper(in_dim=1472, out_dim=2048, hidden_dim=2048, out_dim1=self.hidden_size, use_residual=False)
|
||||||
@@ -88,7 +90,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size)
|
self.img_in = PatchEmbed2D(self.patch_size, self.in_channels, self.hidden_size)
|
||||||
|
|
||||||
# Text token refinement with cross-attention
|
# Text token refinement with cross-attention
|
||||||
self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2, attn_mode=self.attn_mode)
|
self.txt_in = SingleTokenRefiner(text_states_dim, self.hidden_size, self.heads_num, depth=2)
|
||||||
|
|
||||||
# Timestep embedding for diffusion process
|
# Timestep embedding for diffusion process
|
||||||
self.time_in = TimestepEmbedder(self.hidden_size, nn.SiLU)
|
self.time_in = TimestepEmbedder(self.hidden_size, nn.SiLU)
|
||||||
@@ -110,7 +112,6 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
qk_norm=qk_norm,
|
qk_norm=qk_norm,
|
||||||
qk_norm_type=qk_norm_type,
|
qk_norm_type=qk_norm_type,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
attn_mode=self.attn_mode,
|
|
||||||
)
|
)
|
||||||
for _ in range(mm_double_blocks_depth)
|
for _ in range(mm_double_blocks_depth)
|
||||||
]
|
]
|
||||||
@@ -126,7 +127,6 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
mlp_act_type=mlp_act_type,
|
mlp_act_type=mlp_act_type,
|
||||||
qk_norm=qk_norm,
|
qk_norm=qk_norm,
|
||||||
qk_norm_type=qk_norm_type,
|
qk_norm_type=qk_norm_type,
|
||||||
attn_mode=self.attn_mode,
|
|
||||||
)
|
)
|
||||||
for _ in range(mm_single_blocks_depth)
|
for _ in range(mm_single_blocks_depth)
|
||||||
]
|
]
|
||||||
@@ -339,22 +339,21 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
# MeanFlow and guidance embedding not used in this configuration
|
# MeanFlow and guidance embedding not used in this configuration
|
||||||
|
|
||||||
# Process text tokens through refinement layers
|
# Process text tokens through refinement layers
|
||||||
txt_lens = text_mask.to(torch.bool).sum(dim=1).tolist()
|
txt_attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, 0, text_mask)
|
||||||
txt = self.txt_in(txt, t, txt_lens)
|
txt = self.txt_in(txt, t, txt_attn_params)
|
||||||
|
|
||||||
# Integrate character-level ByT5 features with word-level tokens
|
# Integrate character-level ByT5 features with word-level tokens
|
||||||
# Use variable length sequences with sequence lengths
|
# Use variable length sequences with sequence lengths
|
||||||
byt5_txt = self.byt5_in(byt5_text_states)
|
byt5_txt = self.byt5_in(byt5_text_states)
|
||||||
txt, _, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
|
txt, text_mask, txt_lens = self.reorder_txt_token(byt5_txt, txt, byt5_text_mask, text_mask)
|
||||||
|
|
||||||
# Trim sequences to maximum length in the batch
|
# Trim sequences to maximum length in the batch
|
||||||
img_seq_len = img.shape[1]
|
img_seq_len = img.shape[1]
|
||||||
# print(f"img_seq_len: {img_seq_len}, txt_lens: {txt_lens}")
|
|
||||||
seq_lens = [img_seq_len + l for l in txt_lens]
|
|
||||||
max_txt_len = max(txt_lens)
|
max_txt_len = max(txt_lens)
|
||||||
# print(f"max_txt_len: {max_txt_len}, seq_lens: {seq_lens}, txt.shape: {txt.shape}")
|
|
||||||
txt = txt[:, :max_txt_len, :]
|
txt = txt[:, :max_txt_len, :]
|
||||||
txt_seq_len = txt.shape[1]
|
text_mask = text_mask[:, :max_txt_len]
|
||||||
|
|
||||||
|
attn_params = AttentionParams.create_attention_params_from_mask(self.attn_mode, self.split_attn, img_seq_len, text_mask)
|
||||||
|
|
||||||
input_device = img.device
|
input_device = img.device
|
||||||
|
|
||||||
@@ -362,7 +361,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
for index, block in enumerate(self.double_blocks):
|
for index, block in enumerate(self.double_blocks):
|
||||||
if self.blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
self.offloader_double.wait_for_block(index)
|
self.offloader_double.wait_for_block(index)
|
||||||
img, txt = block(img, txt, vec, freqs_cis, seq_lens)
|
img, txt = block(img, txt, vec, freqs_cis, attn_params)
|
||||||
if self.blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
self.offloader_double.submit_move_blocks(self.double_blocks, index)
|
self.offloader_double.submit_move_blocks(self.double_blocks, index)
|
||||||
|
|
||||||
@@ -373,7 +372,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
for index, block in enumerate(self.single_blocks):
|
for index, block in enumerate(self.single_blocks):
|
||||||
if self.blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
self.offloader_single.wait_for_block(index)
|
self.offloader_single.wait_for_block(index)
|
||||||
x = block(x, vec, txt_seq_len, freqs_cis, seq_lens)
|
x = block(x, vec, freqs_cis, attn_params)
|
||||||
if self.blocks_to_swap:
|
if self.blocks_to_swap:
|
||||||
self.offloader_single.submit_move_blocks(self.single_blocks, index)
|
self.offloader_single.submit_move_blocks(self.single_blocks, index)
|
||||||
|
|
||||||
@@ -417,7 +416,7 @@ class HYImageDiffusionTransformer(nn.Module):
|
|||||||
|
|
||||||
def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer:
|
def create_model(attn_mode: str, split_attn: bool, dtype: Optional[torch.dtype]) -> HYImageDiffusionTransformer:
|
||||||
with init_empty_weights():
|
with init_empty_weights():
|
||||||
model = HYImageDiffusionTransformer(attn_mode=attn_mode)
|
model = HYImageDiffusionTransformer(attn_mode=attn_mode, split_attn=split_attn)
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
model.to(dtype)
|
model.to(dtype)
|
||||||
return model
|
return model
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import torch.nn as nn
|
|||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
|
||||||
from library import custom_offloading_utils
|
from library import custom_offloading_utils
|
||||||
from library.attention import attention
|
from library.attention import AttentionParams, attention
|
||||||
from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate
|
from library.hunyuan_image_utils import timestep_embedding, apply_rotary_emb, _to_tuple, apply_gate, modulate
|
||||||
from library.attention import attention
|
from library.attention import attention
|
||||||
|
|
||||||
@@ -213,7 +213,6 @@ class IndividualTokenRefinerBlock(nn.Module):
|
|||||||
qk_norm: QK normalization flag (must be False).
|
qk_norm: QK normalization flag (must be False).
|
||||||
qk_norm_type: QK normalization type (only "layer" supported).
|
qk_norm_type: QK normalization type (only "layer" supported).
|
||||||
qkv_bias: Use bias in QKV projections.
|
qkv_bias: Use bias in QKV projections.
|
||||||
attn_mode: Attention implementation mode.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -226,15 +225,12 @@ class IndividualTokenRefinerBlock(nn.Module):
|
|||||||
qk_norm: bool = False,
|
qk_norm: bool = False,
|
||||||
qk_norm_type: str = "layer",
|
qk_norm_type: str = "layer",
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
attn_mode: str = "torch",
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert qk_norm_type == "layer", "Only layer normalization supported for QK norm."
|
assert qk_norm_type == "layer", "Only layer normalization supported for QK norm."
|
||||||
assert act_type == "silu", "Only SiLU activation supported."
|
assert act_type == "silu", "Only SiLU activation supported."
|
||||||
assert not qk_norm, "QK normalization must be disabled."
|
assert not qk_norm, "QK normalization must be disabled."
|
||||||
|
|
||||||
self.attn_mode = attn_mode
|
|
||||||
|
|
||||||
self.heads_num = heads_num
|
self.heads_num = heads_num
|
||||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||||
|
|
||||||
@@ -253,19 +249,14 @@ class IndividualTokenRefinerBlock(nn.Module):
|
|||||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
|
nn.Linear(hidden_size, 2 * hidden_size, bias=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(self, x: torch.Tensor, c: torch.Tensor, attn_params: AttentionParams) -> torch.Tensor:
|
||||||
self,
|
|
||||||
x: torch.Tensor,
|
|
||||||
c: torch.Tensor, # Combined timestep and context conditioning
|
|
||||||
txt_lens: list[int],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
"""
|
||||||
Apply self-attention and MLP with adaptive conditioning.
|
Apply self-attention and MLP with adaptive conditioning.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Input token embeddings [B, L, C].
|
x: Input token embeddings [B, L, C].
|
||||||
c: Combined conditioning vector [B, C].
|
c: Combined conditioning vector [B, C].
|
||||||
txt_lens: Valid sequence lengths for each batch element.
|
attn_params: Attention parameters including sequence lengths.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Refined token embeddings [B, L, C].
|
Refined token embeddings [B, L, C].
|
||||||
@@ -273,10 +264,14 @@ class IndividualTokenRefinerBlock(nn.Module):
|
|||||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||||
norm_x = self.norm1(x)
|
norm_x = self.norm1(x)
|
||||||
qkv = self.self_attn_qkv(norm_x)
|
qkv = self.self_attn_qkv(norm_x)
|
||||||
|
del norm_x
|
||||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||||
|
del qkv
|
||||||
q = self.self_attn_q_norm(q).to(v)
|
q = self.self_attn_q_norm(q).to(v)
|
||||||
k = self.self_attn_k_norm(k).to(v)
|
k = self.self_attn_k_norm(k).to(v)
|
||||||
attn = attention(q, k, v, seq_lens=txt_lens, attn_mode=self.attn_mode)
|
qkv = [q, k, v]
|
||||||
|
del q, k, v
|
||||||
|
attn = attention(qkv, attn_params=attn_params)
|
||||||
|
|
||||||
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
||||||
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
||||||
@@ -299,7 +294,6 @@ class IndividualTokenRefiner(nn.Module):
|
|||||||
qk_norm: QK normalization flag.
|
qk_norm: QK normalization flag.
|
||||||
qk_norm_type: QK normalization type.
|
qk_norm_type: QK normalization type.
|
||||||
qkv_bias: Use bias in QKV projections.
|
qkv_bias: Use bias in QKV projections.
|
||||||
attn_mode: Attention implementation mode.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -313,7 +307,6 @@ class IndividualTokenRefiner(nn.Module):
|
|||||||
qk_norm: bool = False,
|
qk_norm: bool = False,
|
||||||
qk_norm_type: str = "layer",
|
qk_norm_type: str = "layer",
|
||||||
qkv_bias: bool = True,
|
qkv_bias: bool = True,
|
||||||
attn_mode: str = "torch",
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.blocks = nn.ModuleList(
|
self.blocks = nn.ModuleList(
|
||||||
@@ -327,26 +320,25 @@ class IndividualTokenRefiner(nn.Module):
|
|||||||
qk_norm=qk_norm,
|
qk_norm=qk_norm,
|
||||||
qk_norm_type=qk_norm_type,
|
qk_norm_type=qk_norm_type,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
attn_mode=attn_mode,
|
|
||||||
)
|
)
|
||||||
for _ in range(depth)
|
for _ in range(depth)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, c: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, c: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Apply sequential token refinement.
|
Apply sequential token refinement.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Input token embeddings [B, L, C].
|
x: Input token embeddings [B, L, C].
|
||||||
c: Combined conditioning vector [B, C].
|
c: Combined conditioning vector [B, C].
|
||||||
txt_lens: Valid sequence lengths for each batch element.
|
attn_params: Attention parameters including sequence lengths.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Refined token embeddings [B, L, C].
|
Refined token embeddings [B, L, C].
|
||||||
"""
|
"""
|
||||||
for block in self.blocks:
|
for block in self.blocks:
|
||||||
x = block(x, c, txt_lens)
|
x = block(x, c, attn_params)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -362,10 +354,9 @@ class SingleTokenRefiner(nn.Module):
|
|||||||
hidden_size: Transformer hidden dimension.
|
hidden_size: Transformer hidden dimension.
|
||||||
heads_num: Number of attention heads.
|
heads_num: Number of attention heads.
|
||||||
depth: Number of refinement blocks.
|
depth: Number of refinement blocks.
|
||||||
attn_mode: Attention implementation mode.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int, attn_mode: str = "torch"):
|
def __init__(self, in_channels: int, hidden_size: int, heads_num: int, depth: int):
|
||||||
# Fixed architecture parameters for HunyuanImage-2.1
|
# Fixed architecture parameters for HunyuanImage-2.1
|
||||||
mlp_drop_rate: float = 0.0 # No MLP dropout
|
mlp_drop_rate: float = 0.0 # No MLP dropout
|
||||||
act_type: str = "silu" # SiLU activation
|
act_type: str = "silu" # SiLU activation
|
||||||
@@ -389,17 +380,16 @@ class SingleTokenRefiner(nn.Module):
|
|||||||
qk_norm=qk_norm,
|
qk_norm=qk_norm,
|
||||||
qk_norm_type=qk_norm_type,
|
qk_norm_type=qk_norm_type,
|
||||||
qkv_bias=qkv_bias,
|
qkv_bias=qkv_bias,
|
||||||
attn_mode=attn_mode,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, t: torch.LongTensor, txt_lens: list[int]) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, t: torch.LongTensor, attn_params: AttentionParams) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Refine text embeddings with timestep conditioning.
|
Refine text embeddings with timestep conditioning.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Input text embeddings [B, L, in_channels].
|
x: Input text embeddings [B, L, in_channels].
|
||||||
t: Diffusion timestep [B].
|
t: Diffusion timestep [B].
|
||||||
txt_lens: Valid sequence lengths for each batch element.
|
attn_params: Attention parameters including sequence lengths.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Refined embeddings [B, L, hidden_size].
|
Refined embeddings [B, L, hidden_size].
|
||||||
@@ -407,13 +397,14 @@ class SingleTokenRefiner(nn.Module):
|
|||||||
timestep_aware_representations = self.t_embedder(t)
|
timestep_aware_representations = self.t_embedder(t)
|
||||||
|
|
||||||
# Compute context-aware representations by averaging valid tokens
|
# Compute context-aware representations by averaging valid tokens
|
||||||
|
txt_lens = attn_params.seqlens # img_len is not used for SingleTokenRefiner
|
||||||
context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C]
|
context_aware_representations = torch.stack([x[i, : txt_lens[i]].mean(dim=0) for i in range(x.shape[0])], dim=0) # [B, C]
|
||||||
|
|
||||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||||
c = timestep_aware_representations + context_aware_representations
|
c = timestep_aware_representations + context_aware_representations
|
||||||
del timestep_aware_representations, context_aware_representations
|
del timestep_aware_representations, context_aware_representations
|
||||||
x = self.input_embedder(x)
|
x = self.input_embedder(x)
|
||||||
x = self.individual_token_refiner(x, c, txt_lens)
|
x = self.individual_token_refiner(x, c, attn_params)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -564,7 +555,6 @@ class MMDoubleStreamBlock(nn.Module):
|
|||||||
qk_norm: QK normalization flag (must be True).
|
qk_norm: QK normalization flag (must be True).
|
||||||
qk_norm_type: QK normalization type (only "rms" supported).
|
qk_norm_type: QK normalization type (only "rms" supported).
|
||||||
qkv_bias: Use bias in QKV projections.
|
qkv_bias: Use bias in QKV projections.
|
||||||
attn_mode: Attention implementation mode.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -576,7 +566,6 @@ class MMDoubleStreamBlock(nn.Module):
|
|||||||
qk_norm: bool = True,
|
qk_norm: bool = True,
|
||||||
qk_norm_type: str = "rms",
|
qk_norm_type: str = "rms",
|
||||||
qkv_bias: bool = False,
|
qkv_bias: bool = False,
|
||||||
attn_mode: str = "torch",
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -584,7 +573,6 @@ class MMDoubleStreamBlock(nn.Module):
|
|||||||
assert qk_norm_type == "rms", "Only RMS normalization supported."
|
assert qk_norm_type == "rms", "Only RMS normalization supported."
|
||||||
assert qk_norm, "QK normalization must be enabled."
|
assert qk_norm, "QK normalization must be enabled."
|
||||||
|
|
||||||
self.attn_mode = attn_mode
|
|
||||||
self.heads_num = heads_num
|
self.heads_num = heads_num
|
||||||
head_dim = hidden_size // heads_num
|
head_dim = hidden_size // heads_num
|
||||||
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
||||||
@@ -626,7 +614,7 @@ class MMDoubleStreamBlock(nn.Module):
|
|||||||
self.cpu_offload_checkpointing = False
|
self.cpu_offload_checkpointing = False
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
|
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Extract modulation parameters for image and text streams
|
# Extract modulation parameters for image and text streams
|
||||||
(img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
|
(img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
|
||||||
@@ -687,7 +675,7 @@ class MMDoubleStreamBlock(nn.Module):
|
|||||||
|
|
||||||
qkv = [q, k, v]
|
qkv = [q, k, v]
|
||||||
del q, k, v
|
del q, k, v
|
||||||
attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode)
|
attn = attention(qkv, attn_params=attn_params)
|
||||||
del qkv
|
del qkv
|
||||||
|
|
||||||
# Split attention outputs back to separate streams
|
# Split attention outputs back to separate streams
|
||||||
@@ -719,16 +707,16 @@ class MMDoubleStreamBlock(nn.Module):
|
|||||||
return img, txt
|
return img, txt
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, seq_lens: list[int] = None
|
self, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, freqs_cis: tuple = None, attn_params: AttentionParams = None
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
forward_fn = self._forward
|
forward_fn = self._forward
|
||||||
if self.cpu_offload_checkpointing:
|
if self.cpu_offload_checkpointing:
|
||||||
forward_fn = custom_offloading_utils.cpu_offload_wrapper(forward_fn, self.img_attn_qkv.weight.device)
|
forward_fn = custom_offloading_utils.cpu_offload_wrapper(forward_fn, self.img_attn_qkv.weight.device)
|
||||||
|
|
||||||
return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, seq_lens, use_reentrant=False)
|
return torch.utils.checkpoint.checkpoint(forward_fn, img, txt, vec, freqs_cis, attn_params, use_reentrant=False)
|
||||||
else:
|
else:
|
||||||
return self._forward(img, txt, vec, freqs_cis, seq_lens)
|
return self._forward(img, txt, vec, freqs_cis, attn_params)
|
||||||
|
|
||||||
|
|
||||||
class MMSingleStreamBlock(nn.Module):
|
class MMSingleStreamBlock(nn.Module):
|
||||||
@@ -746,7 +734,6 @@ class MMSingleStreamBlock(nn.Module):
|
|||||||
qk_norm: QK normalization flag (must be True).
|
qk_norm: QK normalization flag (must be True).
|
||||||
qk_norm_type: QK normalization type (only "rms" supported).
|
qk_norm_type: QK normalization type (only "rms" supported).
|
||||||
qk_scale: Attention scaling factor (computed automatically if None).
|
qk_scale: Attention scaling factor (computed automatically if None).
|
||||||
attn_mode: Attention implementation mode.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -758,7 +745,6 @@ class MMSingleStreamBlock(nn.Module):
|
|||||||
qk_norm: bool = True,
|
qk_norm: bool = True,
|
||||||
qk_norm_type: str = "rms",
|
qk_norm_type: str = "rms",
|
||||||
qk_scale: float = None,
|
qk_scale: float = None,
|
||||||
attn_mode: str = "torch",
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -766,7 +752,6 @@ class MMSingleStreamBlock(nn.Module):
|
|||||||
assert qk_norm_type == "rms", "Only RMS normalization supported."
|
assert qk_norm_type == "rms", "Only RMS normalization supported."
|
||||||
assert qk_norm, "QK normalization must be enabled."
|
assert qk_norm, "QK normalization must be enabled."
|
||||||
|
|
||||||
self.attn_mode = attn_mode
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.heads_num = heads_num
|
self.heads_num = heads_num
|
||||||
head_dim = hidden_size // heads_num
|
head_dim = hidden_size // heads_num
|
||||||
@@ -805,9 +790,8 @@ class MMSingleStreamBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
vec: torch.Tensor,
|
vec: torch.Tensor,
|
||||||
txt_len: int,
|
|
||||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||||
seq_lens: list[int] = None,
|
attn_params: AttentionParams = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Extract modulation parameters
|
# Extract modulation parameters
|
||||||
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
||||||
@@ -828,12 +812,10 @@ class MMSingleStreamBlock(nn.Module):
|
|||||||
k = self.k_norm(k).to(v)
|
k = self.k_norm(k).to(v)
|
||||||
|
|
||||||
# Separate image and text tokens
|
# Separate image and text tokens
|
||||||
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
img_q, txt_q = q[:, : attn_params.img_len, :, :], q[:, attn_params.img_len :, :, :]
|
||||||
del q
|
del q
|
||||||
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
img_k, txt_k = k[:, : attn_params.img_len, :, :], k[:, attn_params.img_len :, :, :]
|
||||||
del k
|
del k
|
||||||
# img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
|
|
||||||
# del v
|
|
||||||
|
|
||||||
# Apply rotary position embeddings only to image tokens
|
# Apply rotary position embeddings only to image tokens
|
||||||
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
||||||
@@ -848,7 +830,7 @@ class MMSingleStreamBlock(nn.Module):
|
|||||||
# del img_v, txt_v
|
# del img_v, txt_v
|
||||||
qkv = [q, k, v]
|
qkv = [q, k, v]
|
||||||
del q, k, v
|
del q, k, v
|
||||||
attn = attention(qkv, seq_lens=seq_lens, attn_mode=self.attn_mode)
|
attn = attention(qkv, attn_params=attn_params)
|
||||||
del qkv
|
del qkv
|
||||||
|
|
||||||
# Combine attention and MLP outputs, apply gating
|
# Combine attention and MLP outputs, apply gating
|
||||||
@@ -865,18 +847,17 @@ class MMSingleStreamBlock(nn.Module):
|
|||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
vec: torch.Tensor,
|
vec: torch.Tensor,
|
||||||
txt_len: int,
|
|
||||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
||||||
seq_lens: list[int] = None,
|
attn_params: AttentionParams = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
forward_fn = self._forward
|
forward_fn = self._forward
|
||||||
if self.cpu_offload_checkpointing:
|
if self.cpu_offload_checkpointing:
|
||||||
forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.linear1.weight.device)
|
forward_fn = custom_offloading_utils.create_cpu_offloading_wrapper(forward_fn, self.linear1.weight.device)
|
||||||
|
|
||||||
return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, txt_len, freqs_cis, seq_lens, use_reentrant=False)
|
return torch.utils.checkpoint.checkpoint(forward_fn, x, vec, freqs_cis, attn_params, use_reentrant=False)
|
||||||
else:
|
else:
|
||||||
return self._forward(x, vec, txt_len, freqs_cis, seq_lens)
|
return self._forward(x, vec, freqs_cis, attn_params)
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|||||||
Reference in New Issue
Block a user