From c4958b5dca0102b3f18fa2d2a383f177d508f872 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 20 Jul 2025 16:30:43 +0900 Subject: [PATCH] feat: change img/txt order for attention and single blocks --- library/chroma_models.py | 75 +++++++++++++++------------------------- 1 file changed, 28 insertions(+), 47 deletions(-) diff --git a/library/chroma_models.py b/library/chroma_models.py index 06822a37..1b62f20f 100644 --- a/library/chroma_models.py +++ b/library/chroma_models.py @@ -236,7 +236,8 @@ class DoubleStreamBlock(nn.Module): txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # run actual attention: we split the batch into each element - max_txt_len = txt_q.shape[-2] # max 512 + max_txt_len = torch.max(txt_seq_len).item() + img_len = img_q.shape[-2] # max 64 txt_q = list(torch.chunk(txt_q, txt_q.shape[0], dim=0)) # list of [B, H, L, D] tensors txt_k = list(torch.chunk(txt_k, txt_k.shape[0], dim=0)) txt_v = list(torch.chunk(txt_v, txt_v.shape[0], dim=0)) @@ -246,35 +247,25 @@ class DoubleStreamBlock(nn.Module): txt_attn = [] img_attn = [] for i in range(txt.shape[0]): - print(i) - print(f"len(txt_q) = {len(txt_q)}, len(img_q) = {len(img_q)}, txt_seq_len.shape = {txt_seq_len.shape}") - print(f"txt_seq_len[i] = {txt_seq_len[i]}, txt_q.shape = {txt_q[i].shape}, img_q.shape = {img_q[i].shape}") - txt_q_i = txt_q[i][:, :, : txt_seq_len[i]] + txt_q[i] = txt_q[i][:, :, : txt_seq_len[i]] + q = torch.cat((img_q[i], txt_q[i]), dim=2) txt_q[i] = None - img_q_i = img_q[i] img_q[i] = None - q = torch.cat((txt_q_i, img_q_i), dim=2) - del txt_q_i, img_q_i - txt_k_i = txt_k[i][:, :, : txt_seq_len[i]] + txt_k[i] = txt_k[i][:, :, : txt_seq_len[i]] + k = torch.cat((img_k[i], txt_k[i]), dim=2) txt_k[i] = None - img_k_i = img_k[i] img_k[i] = None - k = torch.cat((txt_k_i, img_k_i), dim=2) - del txt_k_i, img_k_i - txt_v_i = txt_v[i][:, :, : txt_seq_len[i]] + txt_v[i] = txt_v[i][:, :, : txt_seq_len[i]] + v = torch.cat((img_v[i], txt_v[i]), dim=2) txt_v[i] = None - img_v_i = img_v[i] img_v[i] = None - v = torch.cat((txt_v_i, img_v_i), dim=2) - del txt_v_i, img_v_i - attn = attention(q, k, v, pe=pe[i], attn_mask=None) # (1, L, D) - print(f"attn.shape = {attn.shape}, txt_seq_len[i] = {txt_seq_len[i]}, max_txt_len = {max_txt_len}") + attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None) # attn = (1, L, D) + img_attn_i = attn[:, :img_len, :] txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device) - txt_attn_i[:, : txt_seq_len[i], :] = attn[:, : txt_seq_len[i], :] - img_attn_i = attn[:, txt_seq_len[i] :, :] + txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :] txt_attn.append(txt_attn_i) img_attn.append(img_attn_i) @@ -377,9 +368,7 @@ class SingleStreamBlock(nn.Module): def disable_gradient_checkpointing(self): self.gradient_checkpointing = False - def _forward( - self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor, max_txt_len: int - ) -> Tensor: + def _forward(self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor) -> Tensor: mod = distill_vec # replaced with compiled fn # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift @@ -393,25 +382,23 @@ class SingleStreamBlock(nn.Module): # attn = attention(q, k, v, pe=pe, attn_mask=mask) # compute attention: we split the batch into each element + max_txt_len = torch.max(txt_seq_len).item() + img_len = q.shape[-2] - max_txt_len q = list(torch.chunk(q, q.shape[0], dim=0)) k = list(torch.chunk(k, k.shape[0], dim=0)) v = list(torch.chunk(v, v.shape[0], dim=0)) attn = [] for i in range(x.size(0)): - q_i = torch.cat((q[i][:, :, : txt_seq_len[i]], q[i][:, :, max_txt_len:]), dim=2) + q[i] = q[i][:, :, : img_len + txt_seq_len[i]] + k[i] = k[i][:, :, : img_len + txt_seq_len[i]] + v[i] = v[i][:, :, : img_len + txt_seq_len[i]] + attn_trimmed = attention(q[i], k[i], v[i], pe=pe[i : i + 1, :, : img_len + txt_seq_len[i]], attn_mask=None) q[i] = None - k_i = torch.cat((k[i][:, :, : txt_seq_len[i]], k[i][:, :, max_txt_len:]), dim=2) k[i] = None - v_i = torch.cat((v[i][:, :, : txt_seq_len[i]], v[i][:, :, max_txt_len:]), dim=2) v[i] = None - attn_trimmed = attention(q_i, k_i, v_i, pe=pe[i], attn_mask=None) - print( - f"attn_trimmed.shape = {attn_trimmed.shape}, txt_seq_len[i] = {txt_seq_len[i]}, max_txt_len = {max_txt_len}, x.shape = {x.shape}" - ) attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device) - attn_i[:, : txt_seq_len[i], :] = attn_trimmed[:, : txt_seq_len[i], :] - attn_i[:, max_txt_len:, :] = attn_trimmed[:, txt_seq_len[i] :, :] + attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed attn.append(attn_i) attn = torch.cat(attn, dim=0) @@ -422,11 +409,11 @@ class SingleStreamBlock(nn.Module): # return x + mod.gate * output return self.modulation_gate_fn(x, mod.gate, output) - def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], txt_seq_len: Tensor, max_txt_len: int) -> Tensor: + def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], txt_seq_len: Tensor) -> Tensor: if self.training and self.gradient_checkpointing: - return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, max_txt_len, use_reentrant=False) + return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, use_reentrant=False) else: - return self._forward(x, pe, distill_vec, txt_seq_len, max_txt_len) + return self._forward(x, pe, distill_vec, txt_seq_len) class LastLayer(nn.Module): @@ -677,9 +664,6 @@ class Chroma(Flux): mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks) - ids = torch.cat((txt_ids, img_ids), dim=1) - pe = self.pe_embedder(ids) # B, 1, seq_length, 64, 2, 2 - # calculate text length for each batch instead of masking txt_emb_len = txt.shape[1] txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1) # (batch_size, ) @@ -689,12 +673,9 @@ class Chroma(Flux): # trim txt embedding to the text length txt = txt[:, :max_txt_len, :] - # split positional encoding into each element of the batch, and trim masked tokens - print(f"pe shape = {pe.shape} dtype = {pe.dtype}, txt_seq_len = {txt_seq_len}") - pe = list(torch.chunk(pe, pe.shape[0], dim=0)) - for i in range(len(pe)): - # trim positional encoding to the text length - pe[i] = torch.cat([pe[i][:, :, : txt_seq_len[i]], pe[i][:, :, txt_emb_len:]], dim=2) + # create positional encoding for the text and image + ids = torch.cat((img_ids, txt_ids[:, :max_txt_len]), dim=1) # reverse order of ids for faster attention + pe = self.pe_embedder(ids) # B, 1, seq_length, 64, 2, 2 for i, block in enumerate(self.double_blocks): if self.blocks_to_swap: @@ -710,19 +691,19 @@ class Chroma(Flux): if self.blocks_to_swap: self.offloader_double.submit_move_blocks(self.double_blocks, i) - img = torch.cat((txt, img), 1) + img = torch.cat((img, txt), 1) for i, block in enumerate(self.single_blocks): if self.blocks_to_swap: self.offloader_single.wait_for_block(i) single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"] - img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len, max_txt_len=max_txt_len) + img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len) if self.blocks_to_swap: self.offloader_single.submit_move_blocks(self.single_blocks, i) - img = img[:, txt.shape[1] :, ...] + img = img[:, :-max_txt_len, ...] final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"] img = self.final_layer(img, distill_vec=final_mod) # (N, T, patch_size ** 2 * out_channels) return img