From 8fd0b12d1f8bcae52cb11f0ccd193d8382b06166 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 20 Jul 2025 16:00:58 +0900
Subject: [PATCH] feat: update DoubleStreamBlock and SingleStreamBlock to
 handle text sequence lengths instead of mask

---
 library/chroma_models.py | 227 +++++++++++++++++++++++++++------------
 1 file changed, 149 insertions(+), 78 deletions(-)

diff --git a/library/chroma_models.py b/library/chroma_models.py
index f725db87..06822a37 100644
--- a/library/chroma_models.py
+++ b/library/chroma_models.py
@@ -211,9 +211,9 @@ class DoubleStreamBlock(nn.Module):
         self,
         img: Tensor,
         txt: Tensor,
-        pe: Tensor,
+        pe: list[Tensor],
         distill_vec: list[ModulationOut],
-        mask: Tensor,
+        txt_seq_len: Tensor,
     ) -> tuple[Tensor, Tensor]:
         (img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec
 
@@ -235,13 +235,56 @@ class DoubleStreamBlock(nn.Module):
         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
-        # run actual attention
-        q = torch.cat((txt_q, img_q), dim=2)
-        k = torch.cat((txt_k, img_k), dim=2)
-        v = torch.cat((txt_v, img_v), dim=2)
+        # run actual attention: process the batch one element at a time
+        max_txt_len = txt_q.shape[-2]  # max text length in the batch (at most 512)
+        txt_q = list(torch.chunk(txt_q, txt_q.shape[0], dim=0))  # list of [1, H, L, D] tensors
+        txt_k = list(torch.chunk(txt_k, txt_k.shape[0], dim=0))
+        txt_v = list(torch.chunk(txt_v, txt_v.shape[0], dim=0))
+        img_q = list(torch.chunk(img_q, img_q.shape[0], dim=0))
+        img_k = list(torch.chunk(img_k, img_k.shape[0], dim=0))
+        img_v = list(torch.chunk(img_v, img_v.shape[0], dim=0))
+        txt_attn = []
+        img_attn = []
+        for i in range(txt.shape[0]):
+            # trim the text tokens to this element's actual length, then release the chunk
+            txt_q_i = txt_q[i][:, :, : txt_seq_len[i]]
+            txt_q[i] = None
+            img_q_i = img_q[i]
+            img_q[i] = None
+            q = torch.cat((txt_q_i, img_q_i), dim=2)
+            del txt_q_i, img_q_i
 
-        attn = attention(q, k, v, pe=pe, attn_mask=mask)
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+            txt_k_i = txt_k[i][:, :, : txt_seq_len[i]]
+            txt_k[i] = None
+            img_k_i = img_k[i]
+            img_k[i] = None
+            k = torch.cat((txt_k_i, img_k_i), dim=2)
+            del txt_k_i, img_k_i
+
+            txt_v_i = txt_v[i][:, :, : txt_seq_len[i]]
+            txt_v[i] = None
+            img_v_i = img_v[i]
+            img_v[i] = None
+            v = torch.cat((txt_v_i, img_v_i), dim=2)
+            del txt_v_i, img_v_i
+
+            attn = attention(q, k, v, pe=pe[i], attn_mask=None)  # (1, L, D)
+            # scatter the text part back into a zero-padded buffer of the batch max length
+            txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=attn.device)
+            txt_attn_i[:, : txt_seq_len[i], :] = attn[:, : txt_seq_len[i], :]
+            img_attn_i = attn[:, txt_seq_len[i] :, :]
+            txt_attn.append(txt_attn_i)
+            img_attn.append(img_attn_i)
+
+        txt_attn = torch.cat(txt_attn, dim=0)
+        img_attn = torch.cat(img_attn, dim=0)
+
+        # q = torch.cat((txt_q, img_q), dim=2)
+        # k = torch.cat((txt_k, img_k), dim=2)
+        # v = torch.cat((txt_v, img_v), dim=2)
+
+        # attn = attention(q, k, v, pe=pe, attn_mask=mask)
+        # txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
         # replaced with compiled fn
@@ -273,12 +316,12 @@ class DoubleStreamBlock(nn.Module):
         txt: Tensor,
         pe: Tensor,
         distill_vec: list[ModulationOut],
-        mask: Tensor,
+        txt_seq_len: Tensor,
     ) -> tuple[Tensor, Tensor]:
         if self.training and self.gradient_checkpointing:
-            return ckpt.checkpoint(self._forward, img, txt, pe, distill_vec, mask, use_reentrant=False)
+            return ckpt.checkpoint(self._forward, img, txt, pe, distill_vec, txt_seq_len, use_reentrant=False)
         else:
-            return self._forward(img, txt, pe, distill_vec, mask)
+            return self._forward(img, txt, pe, distill_vec, txt_seq_len)
 
 
 class SingleStreamBlock(nn.Module):
@@ -332,7 +375,9 @@ class SingleStreamBlock(nn.Module):
     def disable_gradient_checkpointing(self):
         self.gradient_checkpointing = False
 
-    def _forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor) -> Tensor:
+    def _forward(
+        self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor, max_txt_len: int
+    ) -> Tensor:
         mod = distill_vec
         # replaced with compiled fn
         # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
@@ -342,19 +387,42 @@ class SingleStreamBlock(nn.Module):
         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
         q, k = self.norm(q, k, v)
 
-        # compute attention
-        attn = attention(q, k, v, pe=pe, attn_mask=mask)
+        # # compute attention
+        # attn = attention(q, k, v, pe=pe, attn_mask=mask)
+
+        # compute attention: process the batch one element at a time
+        q = list(torch.chunk(q, q.shape[0], dim=0))
+        k = list(torch.chunk(k, k.shape[0], dim=0))
+        v = list(torch.chunk(v, v.shape[0], dim=0))
+        attn = []
+        for i in range(x.size(0)):
+            q_i = torch.cat((q[i][:, :, : txt_seq_len[i]], q[i][:, :, max_txt_len:]), dim=2)
+            q[i] = None
+            k_i = torch.cat((k[i][:, :, : txt_seq_len[i]], k[i][:, :, max_txt_len:]), dim=2)
+            k[i] = None
+            v_i = torch.cat((v[i][:, :, : txt_seq_len[i]], v[i][:, :, max_txt_len:]), dim=2)
+            v[i] = None
+            attn_trimmed = attention(q_i, k_i, v_i, pe=pe[i], attn_mask=None)
+
+            # re-insert the trimmed text and image parts at their padded positions
+            attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=attn_trimmed.device)
+            attn_i[:, : txt_seq_len[i], :] = attn_trimmed[:, : txt_seq_len[i], :]
+            attn_i[:, max_txt_len:, :] = attn_trimmed[:, txt_seq_len[i] :, :]
+            attn.append(attn_i)
+
+        attn = torch.cat(attn, dim=0)
+
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         # replaced with compiled fn
         # return x + mod.gate * output
         return self.modulation_gate_fn(x, mod.gate, output)
 
-    def forward(self, x: Tensor, pe: Tensor, distill_vec: list[ModulationOut], mask: Tensor) -> Tensor:
+    def forward(self, x: Tensor, pe: list[Tensor], distill_vec: list[ModulationOut], txt_seq_len: Tensor, max_txt_len: int) -> Tensor:
         if self.training and self.gradient_checkpointing:
-            return ckpt.checkpoint(self._forward, x, pe, distill_vec, mask, use_reentrant=False)
+            return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, max_txt_len, use_reentrant=False)
         else:
-            return self._forward(x, pe, distill_vec, mask)
+            return self._forward(x, pe, distill_vec, txt_seq_len, max_txt_len)
 
 
 class LastLayer(nn.Module):
@@ -542,6 +610,29 @@
         self.gradient_checkpointing = False
         self.cpu_offload_checkpointing = False
 
+    def get_mod_vectors(
+        self,
+        timesteps: Tensor,
+        guidance: Tensor | None = None,
+        batch_size: int | None = None,
+        requires_grad: bool = False,
+    ) -> Tensor:
+        distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
+        # TODO: need to add a toggle to omit this for schnell, but that's not a priority
+        distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
+        # get all modulation indices
+        modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2)
+        # broadcast the modulation index so each batch element has all of the indices
+        modulation_index = modulation_index.unsqueeze(0).repeat(batch_size, 1, 1)
+        # broadcast the timestep and guidance along the index dimension as well
+        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
+        # only then can we concatenate them together
+        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
+        if requires_grad:
+            input_vec = input_vec.requires_grad_(True)
+        mod_vectors = self.distilled_guidance_layer(input_vec)
+        return mod_vectors
+
     def forward(
         self,
         img: Tensor,
@@ -554,6 +645,8 @@
         block_controlnet_single_hidden_states=None,
         guidance: Tensor | None = None,
         txt_attention_mask: Tensor | None = None,
+        attn_padding: int = 1,
+        mod_vectors: Tensor | None = None,
     ) -> Tensor:
         # print(
         #     f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
         # )
@@ -567,85 +660,63 @@
         img = self.img_in(img)
         txt = self.txt_in(txt)
 
-        # TODO:
-        # need to fix grad accumulation issue here for now it's in no grad mode
-        # besides, i don't want to wash out the PFP that's trained on this model weights anyway
-        # the fan out operation here is deleting the backward graph
-        # alternatively doing forward pass for every block manually is doable but slow
-        # custom backward probably be better
-        with torch.no_grad():
-            distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
-            # TODO: need to add toggle to omit this from schnell but that's not a priority
-            distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
-            # get all modulation index
-            modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2)
-            # we need to broadcast the modulation index here so each batch has all of the index
-            modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1)
-            # and we need to broadcast timestep and guidance along too
-            timestep_guidance = (
-                torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
-            )
-            # then and only then we could concatenate it together
-            input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
-            mod_vectors = self.distilled_guidance_layer(input_vec.requires_grad_(True))
+        if mod_vectors is None:
+            # TODO:
+            # need to fix the grad accumulation issue here; for now it's in no-grad mode
+            # besides, i don't want to wash out the PFP that's trained on these model weights anyway
+            # the fan-out operation here is deleting the backward graph
+            # alternatively, doing a forward pass for every block manually is doable but slow
+            # a custom backward would probably be better
+            with torch.no_grad():
+                # kohya-ss: I'm not sure why requires_grad is set to True here
+                mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0], requires_grad=True)
+
         mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)
 
         ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
+        pe = self.pe_embedder(ids)  # B, 1, seq_length, 64, 2, 2
 
-        # compute mask
-        # assume max seq length from the batched input
+        # calculate the text length of each batch element instead of building an attention mask
+        txt_emb_len = txt.shape[1]
+        txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1)  # (batch_size, )
+        txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len)
+        max_txt_len = torch.max(txt_seq_len).item()  # max text length in the batch
 
-        max_len = txt.shape[1]
+        # trim the txt embedding to the max text length in the batch
+        txt = txt[:, :max_txt_len, :]
 
-        # mask
-        with torch.no_grad():
-            txt_mask_w_padding = modify_mask_to_attend_padding(txt_attention_mask, max_len, 1)
-            txt_img_mask = torch.cat(
-                [
-                    txt_mask_w_padding,
-                    torch.ones([img.shape[0], img.shape[1]], device=txt_attention_mask.device),
-                ],
-                dim=1,
-            )
-            txt_img_mask = txt_img_mask.float().T @ txt_img_mask.float()
-            txt_img_mask = txt_img_mask[None, None, ...].repeat(txt.shape[0], self.num_heads, 1, 1).int().bool()
-            # txt_mask_w_padding[txt_mask_w_padding==False] = True
+        # split the positional encoding into each element of the batch, and trim masked tokens
+        pe = list(torch.chunk(pe, pe.shape[0], dim=0))
+        for i in range(len(pe)):
+            # trim the positional encoding to this element's text length
+            pe[i] = torch.cat([pe[i][:, :, : txt_seq_len[i]], pe[i][:, :, txt_emb_len:]], dim=2)
 
-        if not self.blocks_to_swap:
-            for i, block in enumerate(self.double_blocks):
-                # the guidance replaced by FFN output
-                img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
-                txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
-                double_mod = [img_mod, txt_mod]
-
-                img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask)
-        else:
-            for i, block in enumerate(self.double_blocks):
+        for i, block in enumerate(self.double_blocks):
+            if self.blocks_to_swap:
                 self.offloader_double.wait_for_block(i)
 
-                # the guidance replaced by FFN output
-                img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
-                txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
-                double_mod = [img_mod, txt_mod]
+            # the guidance is replaced by the FFN output
+            img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
+            txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
+            double_mod = [img_mod, txt_mod]
 
-                img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, mask=txt_img_mask)
+            img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len)
 
+            if self.blocks_to_swap:
                 self.offloader_double.submit_move_blocks(self.double_blocks, i)
 
         img = torch.cat((txt, img), 1)
-        if not self.blocks_to_swap:
-            for i, block in enumerate(self.single_blocks):
-                single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
-                img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
-        else:
-            for i, block in enumerate(self.single_blocks):
+
+        for i, block in enumerate(self.single_blocks):
+            if self.blocks_to_swap:
                 self.offloader_single.wait_for_block(i)
 
-                single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
-                img = block(img, pe=pe, distill_vec=single_mod, mask=txt_img_mask)
+            single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
+            img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len, max_txt_len=max_txt_len)
 
+            if self.blocks_to_swap:
                 self.offloader_single.submit_move_blocks(self.single_blocks, i)
+
         img = img[:, txt.shape[1] :, ...]
 
         final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
         img = self.final_layer(img, distill_vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
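
Note for reviewers: the heart of this patch is replacing one masked attention call over the padded batch with one unmasked call per batch element over trimmed sequences. The following is a minimal, self-contained sketch of that scheme, not the repository code: per_element_attention is a hypothetical name, and plain torch.nn.functional.scaled_dot_product_attention stands in for the repo's attention() helper (which additionally applies the rotary embeddings carried in pe).

import torch
import torch.nn.functional as F


def per_element_attention(txt_qkv, img_qkv, txt_seq_len):
    """Attention over (trimmed txt + img) tokens, one batch element at a time.

    txt_qkv / img_qkv: tuples of (q, k, v), each of shape [B, H, L, D].
    txt_seq_len: [B] tensor with each element's real text length (no padding).
    Returns (txt_attn, img_attn) in the padded layout the blocks expect.
    """
    txt_q, txt_k, txt_v = txt_qkv
    img_q, img_k, img_v = img_qkv
    bsz, heads, max_txt_len, head_dim = txt_q.shape

    txt_attn, img_attn = [], []
    for i in range(bsz):
        t = int(txt_seq_len[i])
        # keep only the real text tokens, then append all image tokens
        q = torch.cat((txt_q[i : i + 1, :, :t], img_q[i : i + 1]), dim=2)
        k = torch.cat((txt_k[i : i + 1, :, :t], img_k[i : i + 1]), dim=2)
        v = torch.cat((txt_v[i : i + 1, :, :t], img_v[i : i + 1]), dim=2)

        # no attn_mask needed: padded tokens were trimmed away, not masked out
        out = F.scaled_dot_product_attention(q, k, v)  # [1, H, t + img_len, D]
        out = out.transpose(1, 2).reshape(1, -1, heads * head_dim)  # [1, L, H * D]

        # scatter the text part back into a zero-padded buffer so the
        # per-element results can be concatenated along the batch dim again
        txt_i = out.new_zeros(1, max_txt_len, heads * head_dim)
        txt_i[:, :t] = out[:, :t]
        txt_attn.append(txt_i)
        img_attn.append(out[:, t:])

    return torch.cat(txt_attn, dim=0), torch.cat(img_attn, dim=0)


# toy shapes: batch 2, 4 heads, 512 padded text tokens, 1024 image tokens
txt_qkv = tuple(torch.randn(2, 4, 512, 64) for _ in range(3))
img_qkv = tuple(torch.randn(2, 4, 1024, 64) for _ in range(3))
txt_attn, img_attn = per_element_attention(txt_qkv, img_qkv, torch.tensor([77, 230]))
print(txt_attn.shape, img_attn.shape)  # torch.Size([2, 512, 256]) torch.Size([2, 1024, 256])

The image part has the same length for every element, so it concatenates directly; only the text part must be zero-filled back to the padded length. That is why the patch zero-fills txt_attn_i in DoubleStreamBlock, and the full attn_i buffer in SingleStreamBlock, where text and image tokens live in one tensor.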
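
The shape bookkeeping in the new get_mod_vectors helper can be traced the same way. This is a sketch under assumed sizes (approximator_in_dim = 64; the modulation-index length of 344 is illustrative), with random tensors standing in for timestep_embedding(..., dim), which returns [N, dim]:

import torch

batch_size, in_dim, mod_index_length = 2, 64, 344

distill_timestep = torch.randn(batch_size, in_dim // 4)        # [B, 16]
distil_guidance = torch.randn(batch_size, in_dim // 4)         # [B, 16]
modulation_index = torch.randn(mod_index_length, in_dim // 2)  # [M, 32]

# broadcast the index so every batch element carries all M indices
modulation_index = modulation_index.unsqueeze(0).repeat(batch_size, 1, 1)  # [B, M, 32]
# broadcast timestep + guidance along the index dimension as well
timestep_guidance = (
    torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
)  # [B, M, 32]
# only then do the two halves line up for concatenation
input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)
print(input_vec.shape)  # torch.Size([2, 344, 64])

input_vec then passes through distilled_guidance_layer to produce one modulation vector per index, which distribute_modulations splits across the blocks; computing this once and passing it in via the new mod_vectors argument is what lets callers reuse it across forward passes.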