From 7a86668ca477087432b0e97bfba958c69065dc5d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 24 Mar 2025 05:13:37 -0400 Subject: [PATCH] Add proportional attention and ntk factor to flux model --- library/flux_models.py | 207 +++++++++++++++++++++++++++++++++-------- 1 file changed, 169 insertions(+), 38 deletions(-) diff --git a/library/flux_models.py b/library/flux_models.py index c9c6101b..bbc05c3d 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -446,10 +446,17 @@ configs = { # region math -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Optional[Tensor] = None) -> Tensor: +def attention( + q: Tensor, + k: Tensor, + v: Tensor, + pe: Tensor, + attn_mask: Optional[Tensor] = None, + attention_scale=None, +) -> Tensor: q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=attention_scale) x = rearrange(x, "B H L D -> B L (H D)") return x @@ -511,10 +518,13 @@ class EmbedND(nn.Module): self.theta = theta self.axes_dim = axes_dim - def forward(self, ids: Tensor) -> Tensor: + def forward(self, ids: Tensor, scale=1.0) -> Tensor: + """ + scale: NTK factor for increasing the embedding space for large images + """ n_axes = ids.shape[-1] emb = torch.cat( - [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + [rope(ids[..., i], self.axes_dim[i], self.theta * scale) for i in range(n_axes)], dim=-3, ) @@ -608,17 +618,23 @@ class SelfAttention(nn.Module): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads + self.head_dim = head_dim self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) # this is not called from DoubleStreamBlock/SingleStreamBlock because they uses attention function directly - def forward(self, x: Tensor, pe: Tensor) -> Tensor: + def forward(self, x: Tensor, pe: Tensor, proportional_attention=False) -> Tensor: qkv = self.qkv(x) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) - x = attention(q, k, v, pe=pe) + if proportional_attention: + train_seq_len = 512 + 64 * 64 + attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / self.head_dim) + else: + attention_scale = math.sqrt(1 / self.head_dim) + x = attention(q, k, v, pe=pe, attention_scale=attention_scale) x = self.proj(x) return x @@ -687,7 +703,13 @@ class DoubleStreamBlock(nn.Module): self.cpu_offload_checkpointing = False def _forward( - self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + txt_attention_mask: Optional[Tensor] = None, + proportional_attention=False, ) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) @@ -717,13 +739,27 @@ class DoubleStreamBlock(nn.Module): # F.scaled_dot_product_attention expects attn_mask to be bool for binary mask attn_mask = txt_attention_mask.to(torch.bool) # b, seq_len attn_mask = torch.cat( - (attn_mask, torch.ones(attn_mask.shape[0], img.shape[1], device=attn_mask.device, dtype=torch.bool)), dim=1 + ( + attn_mask, + torch.ones( + attn_mask.shape[0], + img.shape[1], + device=attn_mask.device, + dtype=torch.bool, + ), + ), + dim=1, ) # b, seq_len + img_len # broadcast attn_mask to all heads attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) - 
attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + if proportional_attention: + train_seq_len = 512 + 64 * 64 + attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / (self.hidden_size // self.num_heads)) + else: + attention_scale = math.sqrt(1 / (self.hidden_size // self.num_heads)) + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask, attention_scale=attention_scale) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img blocks @@ -736,11 +772,26 @@ class DoubleStreamBlock(nn.Module): return img, txt def forward( - self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None + self, + img: Tensor, + txt: Tensor, + vec: Tensor, + pe: Tensor, + txt_attention_mask: Optional[Tensor] = None, + proportional_attention=False, ) -> tuple[Tensor, Tensor]: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, img, txt, vec, pe, txt_attention_mask, use_reentrant=False) + return checkpoint( + self._forward, + img, + txt, + vec, + pe, + txt_attention_mask, + proportional_attention=proportional_attention, + use_reentrant=False, + ) # cpu offload checkpointing def create_custom_forward(func): @@ -752,7 +803,14 @@ class DoubleStreamBlock(nn.Module): return custom_forward return torch.utils.checkpoint.checkpoint( - create_custom_forward(self._forward), img, txt, vec, pe, txt_attention_mask, use_reentrant=False + create_custom_forward(self._forward), + img, + txt, + vec, + pe, + txt_attention_mask, + proportional_attention=proportional_attention, + use_reentrant=False, ) else: @@ -803,7 +861,14 @@ class SingleStreamBlock(nn.Module): self.gradient_checkpointing = False self.cpu_offload_checkpointing = False - def _forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + def _forward( + self, + x: Tensor, + vec: Tensor, + pe: Tensor, + txt_attention_mask: Optional[Tensor] = None, + proportional_attention=False, + ) -> Tensor: mod, _ = self.modulation(vec) x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) @@ -820,7 +885,10 @@ class SingleStreamBlock(nn.Module): ( attn_mask, torch.ones( - attn_mask.shape[0], x.shape[1] - txt_attention_mask.shape[1], device=attn_mask.device, dtype=torch.bool + attn_mask.shape[0], + x.shape[1] - txt_attention_mask.shape[1], + device=attn_mask.device, + dtype=torch.bool, ), ), dim=1, @@ -829,17 +897,29 @@ class SingleStreamBlock(nn.Module): # broadcast attn_mask to all heads attn_mask = attn_mask[:, None, None, :].expand(-1, q.shape[1], q.shape[2], -1) + if proportional_attention: + train_seq_len = 512 + 64 * 64 + attention_scale = math.sqrt(math.log(k.size(2), train_seq_len) / (self.hidden_size // self.num_heads)) + else: + attention_scale = math.sqrt(1 / (self.hidden_size // self.num_heads)) # compute attention - attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask, attention_scale=attention_scale) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output - def forward(self, x: Tensor, vec: Tensor, pe: Tensor, txt_attention_mask: Optional[Tensor] = None) -> Tensor: + def forward( + self, + x: Tensor, + vec: Tensor, + pe: Tensor, + txt_attention_mask: Optional[Tensor] = None, + proportional_attention=False, + 
) -> Tensor: if self.training and self.gradient_checkpointing: if not self.cpu_offload_checkpointing: - return checkpoint(self._forward, x, vec, pe, txt_attention_mask, use_reentrant=False) + return checkpoint(self._forward, x, vec, pe, txt_attention_mask, proportional_attention, use_reentrant=False) # cpu offload checkpointing @@ -852,7 +932,13 @@ class SingleStreamBlock(nn.Module): return custom_forward return torch.utils.checkpoint.checkpoint( - create_custom_forward(self._forward), x, vec, pe, txt_attention_mask, use_reentrant=False + create_custom_forward(self._forward), + x, + vec, + pe, + txt_attention_mask, + proportional_attention=proportional_attention, + use_reentrant=False, ) else: return self._forward(x, vec, pe, txt_attention_mask) @@ -913,10 +999,7 @@ class Flux(nn.Module): ) self.single_blocks = nn.ModuleList( - [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(params.depth_single_blocks) - ] + [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(params.depth_single_blocks)] ) self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) @@ -977,10 +1060,16 @@ class Flux(nn.Module): ) self.offloader_double = custom_offloading_utils.ModelOffloader( - self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + self.double_blocks, + self.num_double_blocks, + double_blocks_to_swap, + device, # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + self.single_blocks, + self.num_single_blocks, + single_blocks_to_swap, + device, # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." 
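
Note for reviewers: the softmax scale this patch threads through `attention()` is computed the same way in `SelfAttention`, `DoubleStreamBlock`, and `SingleStreamBlock`. Below is a minimal standalone sketch of that computation; `proportional_scale` is a helper written only for illustration (it is not a function in flux_models.py) and assumes the training length the diff hardcodes, 512 text tokens plus 64 * 64 latent patches.

```python
import math

import torch
import torch.nn.functional as F

TRAIN_SEQ_LEN = 512 + 64 * 64  # text tokens + 64x64 latent patches, as hardcoded in the diff


def proportional_scale(seq_len: int, head_dim: int, proportional: bool = True) -> float:
    """Softmax scale passed to scaled_dot_product_attention.

    The default branch reproduces SDPA's built-in 1/sqrt(head_dim); the
    proportional branch multiplies the logits by sqrt(log_{TRAIN_SEQ_LEN}(seq_len)),
    keeping attention entropy roughly constant as the sequence grows past the
    training length.
    """
    if proportional:
        return math.sqrt(math.log(seq_len, TRAIN_SEQ_LEN) / head_dim)
    return math.sqrt(1 / head_dim)


# Toy check with small tensors (Flux itself uses 24 heads of dim 128).
q = torch.randn(1, 2, 256, 64)
k = torch.randn(1, 2, 256, 64)
v = torch.randn(1, 2, 256, 64)
out = F.scaled_dot_product_attention(q, k, v, scale=proportional_scale(k.size(2), q.size(-1)))

# At exactly the training length the two branches coincide.
print(proportional_scale(TRAIN_SEQ_LEN, 64), math.sqrt(1 / 64))
```

For sequences longer than the training length the multiplier exceeds 1, sharpening attention in proportion to the log of the sequence length; for shorter sequences the logits are softened slightly.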
@@ -1019,6 +1108,7 @@ class Flux(nn.Module): guidance: Tensor | None = None, txt_attention_mask: Tensor | None = None, proportional_attention=None, + ntk_factor=1.0, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -1034,7 +1124,7 @@ class Flux(nn.Module): txt = self.txt_in(txt) ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) + pe = self.pe_embedder(ids, ntk_factor) if block_controlnet_hidden_states is not None: controlnet_depth = len(block_controlnet_hidden_states) if block_controlnet_single_hidden_states is not None: @@ -1042,20 +1132,40 @@ class Flux(nn.Module): if not self.blocks_to_swap: for block_idx, block in enumerate(self.double_blocks): - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention) + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + txt_attention_mask=txt_attention_mask, + proportional_attention=proportional_attention, + ) if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] img = torch.cat((txt, img), 1) for block_idx, block in enumerate(self.single_blocks): - img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention) + img = block( + img, + vec=vec, + pe=pe, + txt_attention_mask=txt_attention_mask, + proportional_attention=proportional_attention, + ) if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] else: for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention) + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + txt_attention_mask=txt_attention_mask, + proportional_attention=proportional_attention, + ) if block_controlnet_hidden_states is not None and controlnet_depth > 0: img = img + block_controlnet_hidden_states[block_idx % controlnet_depth] @@ -1066,7 +1176,13 @@ class Flux(nn.Module): for block_idx, block in enumerate(self.single_blocks): self.offloader_single.wait_for_block(block_idx) - img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask, proportional_attention=proportional_attention) + img = block( + img, + vec=vec, + pe=pe, + txt_attention_mask=txt_attention_mask, + proportional_attention=proportional_attention, + ) if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0: img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth] @@ -1127,10 +1243,7 @@ class ControlNetFlux(nn.Module): ) self.single_blocks = nn.ModuleList( - [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) - for _ in range(controlnet_single_depth) - ] + [SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(controlnet_single_depth)] ) self.gradient_checkpointing = False @@ -1170,7 +1283,7 @@ class ControlNetFlux(nn.Module): nn.SiLU(), nn.Conv2d(16, 16, 3, padding=1, stride=2), nn.SiLU(), - zero_module(nn.Conv2d(16, 16, 3, padding=1)) + zero_module(nn.Conv2d(16, 16, 3, padding=1)), ) @property @@ -1220,10 +1333,16 @@ class ControlNetFlux(nn.Module): ) self.offloader_double = 
custom_offloading_utils.ModelOffloader( - self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True + self.double_blocks, + self.num_double_blocks, + double_blocks_to_swap, + device, # , debug=True ) self.offloader_single = custom_offloading_utils.ModelOffloader( - self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True + self.single_blocks, + self.num_single_blocks, + single_blocks_to_swap, + device, # , debug=True ) print( f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}." @@ -1285,7 +1404,13 @@ class ControlNetFlux(nn.Module): block_single_samples = () if not self.blocks_to_swap: for block in self.double_blocks: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + txt_attention_mask=txt_attention_mask, + ) block_samples = block_samples + (img,) img = torch.cat((txt, img), 1) @@ -1296,7 +1421,13 @@ class ControlNetFlux(nn.Module): for block_idx, block in enumerate(self.double_blocks): self.offloader_double.wait_for_block(block_idx) - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask) + img, txt = block( + img=img, + txt=txt, + vec=vec, + pe=pe, + txt_attention_mask=txt_attention_mask, + ) block_samples = block_samples + (img,) self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
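
The `ntk_factor` forwarded through `pe_embedder(ids, ntk_factor)` simply multiplies the RoPE base `theta` before the per-axis frequencies are built. A rough sketch of the effect, assuming the usual `omega_i = theta ** (-2i/dim)` frequency construction that `rope()` follows (the helper below is illustrative, not the library function):

```python
import torch


def rope_frequencies(dim: int, theta: float) -> torch.Tensor:
    """Per-channel rotary frequencies omega_i = theta ** (-2i/dim)."""
    exponents = torch.arange(0, dim, 2, dtype=torch.float64) / dim
    return 1.0 / (theta ** exponents)


base = rope_frequencies(dim=64, theta=10_000)              # ntk_factor = 1.0
stretched = rope_frequencies(dim=64, theta=10_000 * 2.0)   # ntk_factor = 2.0

# A larger effective theta lowers the frequencies (longer wavelengths), so the
# positional phases of a larger-than-training image stay inside the range the
# model saw during training -- the usual NTK-aware RoPE extension trick.
print(base[-1].item(), stretched[-1].item())
```

Because `EmbedND` shares one `theta` across its axes, the factor stretches the text-index, image-row, and image-column frequencies uniformly. In use, a caller would pass `proportional_attention=True` and an `ntk_factor` above 1.0 to `Flux.forward` when sampling at resolutions beyond the training size; both options default to the original behaviour.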