Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-15 08:36:41 +00:00)
Add proportional attention and ntk factor to flux model
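In short: proportional_attention swaps the default softmax scaling 1/sqrt(head_dim) for sqrt(log_n(S) / head_dim), where S is the current sequence length and n is the training sequence length (hard-coded below as 512 + 64 * 64, i.e. the text tokens plus a 64x64-token latent), so attention stays comparably sharp when sampling above the training resolution. ntk_factor multiplies the RoPE base theta to widen the positional-embedding range for large images. A minimal, self-contained sketch of the proportional scale follows; the helper name, the assert, and the head_dim example are illustrative, not part of the commit:

import math

# Sketch (not from the diff): the scale handed to scaled_dot_product_attention
# when proportional attention is enabled. At seq_len == train_seq_len this
# reduces to the default 1 / sqrt(head_dim).
def proportional_attention_scale(seq_len: int, head_dim: int, train_seq_len: int = 512 + 64 * 64) -> float:
    return math.sqrt(math.log(seq_len, train_seq_len) / head_dim)

# FLUX uses 24 heads on a 3072-wide hidden state, i.e. head_dim = 128.
assert abs(proportional_attention_scale(512 + 64 * 64, 128) - math.sqrt(1 / 128)) < 1e-12
# A latent with twice the side length (128x128 tokens) gets a slightly larger scale:
print(proportional_attention_scale(512 + 128 * 128, 128))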
@@ -446,10 +446,17 @@ configs = {
 # region math


-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor:
+def attention(
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    pe: Tensor,
+    attn_mask: Optional[Tensor] = None,
+    attention_scale=None,
+) -> Tensor:
     q, k = apply_rope(q, k, pe)

-    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=attention_scale)
     x = rearrange(x, "B H L D -> B L (H D)")

     return x
@@ -511,10 +518,13 @@ class EmbedND(nn.Module):
         self.theta = theta
         self.axes_dim = axes_dim

-    def forward(self, ids: Tensor) -> Tensor:
+    def forward(self, ids: Tensor, scale=1.0) -> Tensor:
+        """
+        scale: NTK factor for increasing the embedding space for large images
+        """
         n_axes = ids.shape[-1]
         emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+            [rope(ids[..., i], self.axes_dim[i], self.theta * scale) for i in range(n_axes)],
             dim=-3,
         )

@@ -608,17 +618,23 @@ class SelfAttention(nn.Module):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
+        self.head_dim = head_dim

         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.norm = QKNorm(head_dim)
         self.proj = nn.Linear(dim, dim)

     # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly
-    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
+    def forward(self, x: Tensor, pe: Tensor, proportional_attention=False) -> Tensor:
         qkv = self.qkv(x)
         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
         q, k = self.norm(q, k, v)
-        x = attention(q, k, v, pe=pe)
+        if proportional_attention:
+            train_seq_len = 512 + 64 * 64
+            attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / self.head_dim)
+        else:
+            attention_scale = math.sqrt(1 / self.head_dim)
+        x = attention(q, k, v, pe=pe, attention_scale=attention_scale)
         x = self.proj(x)
         return x

@@ -687,7 +703,13 @@ class DoubleStreamBlock(nn.Module):
         self.cpu_offload_checkpointing = False

     def _forward(
-        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
+        self,
+        img: Tensor,
+        txt: Tensor,
+        vec: Tensor,
+        pe: Tensor,
+        txt_attention_mask: Optional[Tensor] = None,
+        proportional_attention=False,
     ) -> tuple[Tensor, Tensor]:
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)
@@ -717,13 +739,27 @@ class DoubleStreamBlock(nn.Module):
             # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask
             attn_mask = txt_attention_mask.to(torch.bool)  # b, seq_len
             attn_mask = torch.cat(
-                (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1
+                (
+                    attn_mask,
+                    torch.ones(
+                        attn_mask.shape[0],
+                        img.shape[1],
+                        device=attn_mask.device,
+                        dtype=torch.bool,
+                    ),
+                ),
+                dim=1,
             )  # b, seq_len + img_len

             # broadcast attn_mask to all heads
             attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)

-        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
+        if proportional_attention:
+            train_seq_len = 512 + 64 * 64
+            attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / (self.hidden_size // self.num_heads))
+        else:
+            attention_scale = math.sqrt(1 / (self.hidden_size // self.num_heads))
+        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask, attention_scale=attention_scale)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

         # calculate the img blocks
@@ -736,11 +772,26 @@ class DoubleStreamBlock(nn.Module):
         return img, txt

     def forward(
-        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None
+        self,
+        img: Tensor,
+        txt: Tensor,
+        vec: Tensor,
+        pe: Tensor,
+        txt_attention_mask: Optional[Tensor] = None,
+        proportional_attention=False,
     ) -> tuple[Tensor, Tensor]:
         if self.training and self.gradient_checkpointing:
             if not self.cpu_offload_checkpointing:
-                return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False)
+                return checkpoint(
+                    self._forward,
+                    img,
+                    txt,
+                    vec,
+                    pe,
+                    txt_attention_mask,
+                    proportional_attention=proportional_attention,
+                    use_reentrant=False,
+                )
             # cpu offload checkpointing

             def create_custom_forward(func):
@@ -752,7 +803,14 @@ class DoubleStreamBlock(nn.Module):
                 return custom_forward

             return torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False
+                create_custom_forward(self._forward),
+                img,
+                txt,
+                vec,
+                pe,
+                txt_attention_mask,
+                proportional_attention=proportional_attention,
+                use_reentrant=False,
             )

         else:
@@ -803,7 +861,14 @@ class SingleStreamBlock(nn.Module):
         self.gradient_checkpointing = False
         self.cpu_offload_checkpointing = False

-    def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
+    def _forward(
+        self,
+        x: Tensor,
+        vec: Tensor,
+        pe: Tensor,
+        txt_attention_mask: Optional[Tensor] = None,
+        proportional_attention=False,
+    ) -> Tensor:
         mod, _ = self.modulation(vec)
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -820,7 +885,10 @@ class SingleStreamBlock(nn.Module):
                 (
                     attn_mask,
                     torch.ones(
-                        attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool
+                        attn_mask.shape[0],
+                        x.shape[1] - txt_attention_mask.shape[1],
+                        device=attn_mask.device,
+                        dtype=torch.bool,
                     ),
                 ),
                 dim=1,
@@ -829,17 +897,29 @@ class SingleStreamBlock(nn.Module):
             # broadcast attn_mask to all heads
             attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1)

+        if proportional_attention:
+            train_seq_len = 512 + 64 * 64
+            attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / (self.hidden_size // self.num_heads))
+        else:
+            attention_scale = math.sqrt(1 / (self.hidden_size // self.num_heads))
         # compute attention
-        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
+        attn = attention(q, k, v, pe=pe, attn_mask=attn_mask, attention_scale=attention_scale)

         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         return x + mod.gate * output

-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor:
+    def forward(
+        self,
+        x: Tensor,
+        vec: Tensor,
+        pe: Tensor,
+        txt_attention_mask: Optional[Tensor] = None,
+        proportional_attention=False,
+    ) -> Tensor:
         if self.training and self.gradient_checkpointing:
             if not self.cpu_offload_checkpointing:
-                return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False)
+                return checkpoint(self._forward, x, vec, pe, txt_attention_mask, proportional_attention, use_reentrant=False)

             # cpu offload checkpointing

@@ -852,7 +932,13 @@ class SingleStreamBlock(nn.Module):
                 return custom_forward

             return torch.utils.checkpoint.checkpoint(
-                create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False
+                create_custom_forward(self._forward),
+                x,
+                vec,
+                pe,
+                txt_attention_mask,
+                proportional_attention=proportional_attention,
+                use_reentrant=False,
             )
         else:
             return self._forward(x, vec, pe, txt_attention_mask)
@@ -913,10 +999,7 @@ class Flux(nn.Module):
         )

         self.single_blocks = nn.ModuleList(
-            [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
-                for _ in range(params.depth_single_blocks)
-            ]
+            [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(params.depth_single_blocks)]
         )

         self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
@@ -977,10 +1060,16 @@ class Flux(nn.Module):
         )

         self.offloader_double = custom_offloading_utils.ModelOffloader(
-            self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device  # , debug=True
+            self.double_blocks,
+            self.num_double_blocks,
+            double_blocks_to_swap,
+            device,  # , debug=True
         )
         self.offloader_single = custom_offloading_utils.ModelOffloader(
-            self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device  # , debug=True
+            self.single_blocks,
+            self.num_single_blocks,
+            single_blocks_to_swap,
+            device,  # , debug=True
         )
         print(
             f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1019,6 +1108,7 @@ class Flux(nn.Module):
         guidance: Tensor | None = None,
         txt_attention_mask: Tensor | None = None,
+        proportional_attention=None,
+        ntk_factor=1.0,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -1034,7 +1124,7 @@ class Flux(nn.Module):
         txt = self.txt_in(txt)

         ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
+        pe = self.pe_embedder(ids, ntk_factor)
         if block_controlnet_hidden_states is not None:
             controlnet_depth = len(block_controlnet_hidden_states)
         if block_controlnet_single_hidden_states is not None:
@@ -1042,20 +1132,40 @@ class Flux(nn.Module):

         if not self.blocks_to_swap:
             for block_idx, block in enumerate(self.double_blocks):
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention)
+                img, txt = block(
+                    img=img,
+                    txt=txt,
+                    vec=vec,
+                    pe=pe,
+                    txt_attention_mask=txt_attention_mask,
+                    proportional_attention=proportional_attention,
+                )
                 if block_controlnet_hidden_states is not None and controlnet_depth > 0:
                     img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]

             img = torch.cat((txt, img), 1)
             for block_idx, block in enumerate(self.single_blocks):
-                img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention)
+                img = block(
+                    img,
+                    vec=vec,
+                    pe=pe,
+                    txt_attention_mask=txt_attention_mask,
+                    proportional_attention=proportional_attention,
+                )
                 if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
                     img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
         else:
             for block_idx, block in enumerate(self.double_blocks):
                 self.offloader_double.wait_for_block(block_idx)

-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention)
+                img, txt = block(
+                    img=img,
+                    txt=txt,
+                    vec=vec,
+                    pe=pe,
+                    txt_attention_mask=txt_attention_mask,
+                    proportional_attention=proportional_attention,
+                )
                 if block_controlnet_hidden_states is not None and controlnet_depth > 0:
                     img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]

@@ -1066,7 +1176,13 @@ class Flux(nn.Module):
             for block_idx, block in enumerate(self.single_blocks):
                 self.offloader_single.wait_for_block(block_idx)

-                img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention)
+                img = block(
+                    img,
+                    vec=vec,
+                    pe=pe,
+                    txt_attention_mask=txt_attention_mask,
+                    proportional_attention=proportional_attention,
+                )
                 if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
                     img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]

@@ -1127,10 +1243,7 @@ class ControlNetFlux(nn.Module):
         )

         self.single_blocks = nn.ModuleList(
-            [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
-                for _ in range(controlnet_single_depth)
-            ]
+            [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(controlnet_single_depth)]
         )

         self.gradient_checkpointing = False
@@ -1170,7 +1283,7 @@ class ControlNetFlux(nn.Module):
             nn.SiLU(),
             nn.Conv2d(16, 16, 3, padding=1, stride=2),
             nn.SiLU(),
-            zero_module(nn.Conv2d(16, 16, 3, padding=1))
+            zero_module(nn.Conv2d(16, 16, 3, padding=1)),
         )

     @property
@@ -1220,10 +1333,16 @@ class ControlNetFlux(nn.Module):
         )

         self.offloader_double = custom_offloading_utils.ModelOffloader(
-            self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device  # , debug=True
+            self.double_blocks,
+            self.num_double_blocks,
+            double_blocks_to_swap,
+            device,  # , debug=True
         )
         self.offloader_single = custom_offloading_utils.ModelOffloader(
-            self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device  # , debug=True
+            self.single_blocks,
+            self.num_single_blocks,
+            single_blocks_to_swap,
+            device,  # , debug=True
         )
         print(
             f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1285,7 +1404,13 @@ class ControlNetFlux(nn.Module):
         block_single_samples = ()
         if not self.blocks_to_swap:
             for block in self.double_blocks:
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
+                img, txt = block(
+                    img=img,
+                    txt=txt,
+                    vec=vec,
+                    pe=pe,
+                    txt_attention_mask=txt_attention_mask,
+                )
                 block_samples = block_samples + (img,)

             img = torch.cat((txt, img), 1)
@@ -1296,7 +1421,13 @@ class ControlNetFlux(nn.Module):
             for block_idx, block in enumerate(self.double_blocks):
                 self.offloader_double.wait_for_block(block_idx)

-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
+                img, txt = block(
+                    img=img,
+                    txt=txt,
+                    vec=vec,
+                    pe=pe,
+                    txt_attention_mask=txt_attention_mask,
+                )
                 block_samples = block_samples + (img,)

                 self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
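Endnote on the NTK factor: its only functional change is the self.theta * scale multiplication inside EmbedND.forward, i.e. the RoPE base is enlarged before the per-axis rope(...) call, which lowers the rotary frequencies so the larger position ids of big latents still fall inside the phase range seen during training. A rough, standalone illustration of that effect; the function and variable names here are illustrative, not from the repository:

import torch

def rope_frequencies(dim: int, theta: float, ntk_factor: float = 1.0) -> torch.Tensor:
    # omega_i = (theta * ntk_factor) ** (-2i / dim): multiplying theta by a factor > 1
    # shrinks every non-constant frequency, stretching the usable position range.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64) / dim
    return 1.0 / ((theta * ntk_factor) ** exponents)

base = rope_frequencies(16, 10000.0)
stretched = rope_frequencies(16, 10000.0, ntk_factor=2.0)
# index 0 is the constant frequency and stays 1.0; all others get smaller,
# so the same pixel index produces a smaller rotation angle.
assert torch.all(stretched[1:] < base[1:])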