Add proportional attention and NTK factor to Flux model

rockerBOO
2025-03-24 05:13:37 -04:00
parent b46953d5a2
commit 7a86668ca4


@@ -446,10 +446,17 @@ configs = {
# region math
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
def attention(
q: Tensor,
k: Tensor,
v: Tensor,
pe: Tensor,
attn_mask: Optional[Tensor] = None,
attention_scale=None,
) -> Tensor:
q, k = apply_rope(q, k, pe)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=attention_scale)
x = rearrange(x, "B H L D -> B L (H D)")
return x
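
When attention_scale is None, PyTorch's scaled_dot_product_attention falls back to its built-in default of 1/sqrt(head_dim), so existing callers that pass no scale keep the previous behavior. A minimal sketch of that equivalence (assumes PyTorch >= 2.1, where SDPA accepts the scale kwarg):

import math
import torch
import torch.nn.functional as F

B, H, L, D = 1, 2, 16, 64
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))

default = F.scaled_dot_product_attention(q, k, v)  # scale=None -> 1/sqrt(D)
explicit = F.scaled_dot_product_attention(q, k, v, scale=1 / math.sqrt(D))
assert torch.allclose(default, explicit)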
@@ -511,10 +518,13 @@ class EmbedND(nn.Module):
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
def forward(self, ids: Tensor, scale=1.0) -> Tensor:
"""
scale: NTK factor that multiplies theta to stretch the positional embedding range for large images
"""
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
[rope(ids[..., i], self.axes_dim[i], self.theta * scale) for i in range(n_axes)],
dim=-3,
)
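
Multiplying theta by the NTK factor lowers every rotary frequency, stretching the wavelengths so that positions beyond the training grid land in phase ranges the model has already seen. A small sketch of the effect, assuming the usual Flux-style frequency schedule omega_i = theta**(-2i/dim) inside rope:

import torch

def rope_freqs(dim: int, theta: float) -> torch.Tensor:
    # Frequency schedule assumed to match Flux's rope helper.
    return 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))

base = rope_freqs(16, theta=10_000)
ntk = rope_freqs(16, theta=10_000 * 2.0)  # i.e. ntk_factor = 2.0

# Every frequency shrinks (wavelengths stretch); the lowest-frequency bands
# change the most, which is what extends the usable position range.
assert torch.all(ntk <= base)

# For reference, the NTK-aware interpolation literature often picks the factor
# as k ** (dim / (dim - 2)) for a k-fold extension of the position range.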
@@ -608,17 +618,23 @@ class SelfAttention(nn.Module):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.head_dim = head_dim
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
# this is not called from DoubleStreamBlock/SingleStreamBlock because they use the attention function directly
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
def forward(self, x: Tensor, pe: Tensor, proportional_attention=False) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
if proportional_attention:
train_seq_len = 512 + 64 * 64
attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / self.head_dim)
else:
attention_scale = math.sqrt(1 / self.head_dim)
x = attention(q, k, v, pe=pe, attention_scale=attention_scale)
x = self.proj(x)
return x
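
The proportional scale reduces to the standard 1/sqrt(head_dim) exactly at the assumed training length (512 text tokens plus a 64x64 latent grid = 4608 tokens) and grows logarithmically beyond it, sharpening the logits to keep attention entropy roughly stable as the sequence lengthens. A worked example (head_dim = 128 assumes Flux's hidden size of 3072 across 24 heads):

import math

head_dim = 128                   # assumed: Flux hidden_size 3072 / 24 heads
train_seq_len = 512 + 64 * 64    # 4608 tokens, as hard-coded above

def prop_scale(seq_len: int) -> float:
    return math.sqrt(math.log(seq_len, train_seq_len) / head_dim)

print(prop_scale(train_seq_len))    # 0.0884 == 1/sqrt(128), the default scale
print(prop_scale(512 + 128 * 128))  # 0.0950, sharper logits at ~4x the image area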
@@ -687,7 +703,13 @@ class DoubleStreamBlock(nn.Module):
self.cpu_offload_checkpointing = False
def _forward(
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
self,
img: Tensor,
txt: Tensor,
vec: Tensor,
pe: Tensor,
txt_attention_mask: Optional[Tensor] = None,
proportional_attention=False,
) -> tuple[Tensor, Tensor]:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -717,13 +739,27 @@ class DoubleStreamBlock(nn.Module):
# F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len
attn_mask = torch.cat(
(attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
(
attn_mask,
torch.ones(
attn_mask.shape[0],
img.shape[1],
device=attn_mask.device,
dtype=torch.bool,
),
),
dim=1,
) # b, seq_len + img_len
# broadcast attn_mask to all heads
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
if proportional_attention:
train_seq_len = 512 + 64 * 64
attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / (self.hidden_size // self.num_heads))
else:
attention_scale = math.sqrt(1 / (self.hidden_size // self.num_heads))
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask, attention_scale=attention_scale)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img blocks
@@ -736,11 +772,26 @@ class DoubleStreamBlock(nn.Module):
return img, txt
def forward(
self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
self,
img: Tensor,
txt: Tensor,
vec: Tensor,
pe: Tensor,
txt_attention_mask: Optional[Tensor] = None,
proportional_attention=False,
) -> tuple[Tensor, Tensor]:
if self.training and self.gradient_checkpointing:
if not self.cpu_offload_checkpointing:
return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
return checkpoint(
self._forward,
img,
txt,
vec,
pe,
txt_attention_mask,
proportional_attention=proportional_attention,
use_reentrant=False,
)
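
Passing proportional_attention as a keyword through checkpoint works here because non-reentrant checkpointing (use_reentrant=False) forwards extra keyword arguments to the wrapped function; the legacy reentrant path does not accept them. A minimal sketch:

import torch
from torch.utils.checkpoint import checkpoint

def f(x, flag=False):
    return x * (2.0 if flag else 1.0)

x = torch.ones(3, requires_grad=True)
y = checkpoint(f, x, flag=True, use_reentrant=False)  # kwargs reach f
y.sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))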
# cpu offload checkpointing
def create_custom_forward(func):
@@ -752,7 +803,14 @@ class DoubleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(
create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
create_custom_forward(self._forward),
img,
txt,
vec,
pe,
txt_attention_mask,
proportional_attention=proportional_attention,
use_reentrant=False,
)
else:
@@ -803,7 +861,14 @@ class SingleStreamBlock(nn.Module):
self.gradient_checkpointing = False
self.cpu_offload_checkpointing = False
def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
def _forward(
self,
x: Tensor,
vec: Tensor,
pe: Tensor,
txt_attention_mask: Optional[Tensor] = None,
proportional_attention=False,
) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -820,7 +885,10 @@ class SingleStreamBlock(nn.Module):
(
attn_mask,
torch.ones(
attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
attn_mask.shape[0],
x.shape[1] - txt_attention_mask.shape[1],
device=attn_mask.device,
dtype=torch.bool,
),
),
dim=1,
@@ -829,17 +897,29 @@ class SingleStreamBlock(nn.Module):
# broadcast attn_mask to all heads
attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)
if proportional_attention:
train_seq_len = 512 + 64 * 64
attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / (self.hidden_size // self.num_heads))
else:
attention_scale = math.sqrt(1 / (self.hidden_size // self.num_heads))
# compute attention
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask, attention_scale=attention_scale)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
def forward(
self,
x: Tensor,
vec: Tensor,
pe: Tensor,
txt_attention_mask: Optional[Tensor] = None,
proportional_attention=False,
) -> Tensor:
if self.training and self.gradient_checkpointing:
if not self.cpu_offload_checkpointing:
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
return checkpoint(self._forward, x, vec, pe, txt_attention_mask, proportional_attention, use_reentrant=False)
# cpu offload checkpointing
@@ -852,7 +932,13 @@ class SingleStreamBlock(nn.Module):
return custom_forward
return torch.utils.checkpoint.checkpoint(
create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
create_custom_forward(self._forward),
x,
vec,
pe,
txt_attention_mask,
proportional_attention=proportional_attention,
use_reentrant=False,
)
else:
return self._forward(x, vec, pe, txt_attention_mask, proportional_attention)
@@ -913,10 +999,7 @@ class Flux(nn.Module):
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
[SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(params.depth_single_blocks)]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
@@ -977,10 +1060,16 @@ class Flux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks,
self.num_double_blocks,
double_blocks_to_swap,
device, # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks,
self.num_single_blocks,
single_blocks_to_swap,
device, # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1019,6 +1108,7 @@ class Flux(nn.Module):
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
proportional_attention=None,
ntk_factor=1.0,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -1034,7 +1124,7 @@ class Flux(nn.Module):
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
pe = self.pe_embedder(ids, ntk_factor)
if block_controlnet_hidden_states is not None:
controlnet_depth = len(block_controlnet_hidden_states)
if block_controlnet_single_hidden_states is not None:
@@ -1042,20 +1132,40 @@ class Flux(nn.Module):
if not self.blocks_to_swap:
for block_idx, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention)
img, txt = block(
img=img,
txt=txt,
vec=vec,
pe=pe,
txt_attention_mask=txt_attention_mask,
proportional_attention=proportional_attention,
)
if block_controlnet_hidden_states is not None and controlnet_depth > 0:
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
img = torch.cat((txt, img), 1)
for block_idx, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention)
img = block(
img,
vec=vec,
pe=pe,
txt_attention_mask=txt_attention_mask,
proportional_attention=proportional_attention,
)
if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
else:
for block_idx, block in enumerate(self.double_blocks):
self.offloader_double.wait_for_block(block_idx)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention)
img, txt = block(
img=img,
txt=txt,
vec=vec,
pe=pe,
txt_attention_mask=txt_attention_mask,
proportional_attention=proportional_attention,
)
if block_controlnet_hidden_states is not None and controlnet_depth > 0:
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
@@ -1066,7 +1176,13 @@ class Flux(nn.Module):
for block_idx, block in enumerate(self.single_blocks):
self.offloader_single.wait_for_block(block_idx)
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention)
img = block(
img,
vec=vec,
pe=pe,
txt_attention_mask=txt_attention_mask,
proportional_attention=proportional_attention,
)
if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
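
Taken together, a hypothetical caller-side sketch of the two new knobs, mirroring the usual Flux forward signature (the model handle and input tensors are placeholders; the 2.0 is illustrative, not tuned):

# Hypothetical usage; flux_model and the input tensors are placeholders.
pred = flux_model(
    img=img, img_ids=img_ids, txt=txt, txt_ids=txt_ids,
    timesteps=timesteps, y=y, guidance=guidance,
    txt_attention_mask=txt_attention_mask,
    proportional_attention=True,  # entropy-proportional logit scale in every block
    ntk_factor=2.0,               # stretch RoPE wavelengths for above-training resolutions
)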
@@ -1127,10 +1243,7 @@ class ControlNetFlux(nn.Module):
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(controlnet_single_depth)
]
[SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(controlnet_single_depth)]
)
self.gradient_checkpointing = False
@@ -1170,7 +1283,7 @@ class ControlNetFlux(nn.Module):
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(nn.Conv2d(16, 16, 3, padding=1))
zero_module(nn.Conv2d(16, 16, 3, padding=1)),
)
@property
@@ -1220,10 +1333,16 @@ class ControlNetFlux(nn.Module):
)
self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks,
self.num_double_blocks,
double_blocks_to_swap,
device, # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks,
self.num_single_blocks,
single_blocks_to_swap,
device, # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1285,7 +1404,13 @@ class ControlNetFlux(nn.Module):
block_single_samples = ()
if not self.blocks_to_swap:
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
img, txt = block(
img=img,
txt=txt,
vec=vec,
pe=pe,
txt_attention_mask=txt_attention_mask,
)
block_samples = block_samples + (img,)
img = torch.cat((txt, img), 1)
@@ -1296,7 +1421,13 @@ class ControlNetFlux(nn.Module):
for block_idx, block in enumerate(self.double_blocks):
self.offloader_double.wait_for_block(block_idx)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
img, txt = block(
img=img,
txt=txt,
vec=vec,
pe=pe,
txt_attention_mask=txt_attention_mask,
)
block_samples = block_samples + (img,)
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)