From 0b763ef1f17fc9117b630c3478c6ae02437ac07e Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Sun, 20 Jul 2025 20:53:06 +0900
Subject: [PATCH] feat: fix timestep for input_vec for Chroma

---
 flux_train_network.py    |  4 +---
 library/chroma_models.py | 36 ++++++++++++++++++++++++++++++------
 library/flux_models.py   |  3 +++
 3 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/flux_train_network.py b/flux_train_network.py
index 1b61ac72..13e9ae2a 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -341,9 +341,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
 
             # get modulation vectors for Chroma
-            input_vec = None
-            if self.model_type == "chroma":
-                input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz)
+            input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
 
             if args.gradient_checkpointing:
                 noisy_model_input.requires_grad_(True)
diff --git a/library/chroma_models.py b/library/chroma_models.py
index e5d3b547..b9c54db4 100644
--- a/library/chroma_models.py
+++ b/library/chroma_models.py
@@ -223,7 +223,10 @@ class DoubleStreamBlock(nn.Module):
         # img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift)
         img_qkv = self.img_attn.qkv(img_modulated)
+        del img_modulated
+
         img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
@@ -232,7 +235,10 @@
         # txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
+
         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention: we split the batch into each element
@@ -263,9 +269,11 @@
                 img_v[i] = None
 
             attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None)  # attn = (1, L, D)
+            del q, k, v
             img_attn_i = attn[:, :img_len, :]
             txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device)
             txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :]
+            del attn
 
             txt_attn.append(txt_attn_i)
             img_attn.append(img_attn_i)
@@ -279,27 +287,31 @@
         # attn = attention(q, k, v, pe=pe, attn_mask=mask)
         # txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
-        # calculate the img bloks
+        # calculate the img blocks
         # replaced with compiled fn
         # img = img + img_mod1.gate * self.img_attn.proj(img_attn)
         # img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
         img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn))
+        del img_attn, img_mod1
         img = self.modulation_gate_fn(
             img,
             img_mod2.gate,
             self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)),
         )
+        del img_mod2
 
-        # calculate the txt bloks
+        # calculate the txt blocks
         # replaced with compiled fn
         # txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
         # txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
         txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn))
+        del txt_attn, txt_mod1
         txt = self.modulation_gate_fn(
             txt,
             txt_mod2.gate,
             self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)),
         )
+        del txt_mod2
 
         return img, txt
 
@@ -374,8 +386,10 @@ class SingleStreamBlock(nn.Module):
         # x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift)
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        del x_mod
 
         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        del qkv
         q, k = self.norm(q, k, v)
 
         # # compute attention
@@ -399,12 +413,15 @@
             attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device)
             attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed
+            del attn_trimmed
             attn.append(attn_i)
 
         attn = torch.cat(attn, dim=0)
 
         # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        mlp = self.mlp_act(mlp)
+        output = self.linear2(torch.cat((attn, mlp), 2))
+        del attn, mlp
 
         # replaced with compiled fn
         # return x + mod.gate * output
         return self.modulation_gate_fn(x, mod.gate, output)
@@ -625,6 +642,7 @@ class Chroma(Flux):
             print("Chroma: Gradient checkpointing disabled.")
 
     def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
+        # print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}")
         distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
         # TODO: need to add toggle to omit this from schnell but that's not a priority
         distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
@@ -656,6 +674,7 @@
         # print(
         #     f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
         # )
+        # print(f"input_vec shape: {input_vec.shape if input_vec is not None else 'None'}")
         # print(f"timesteps: {timesteps}, guidance: {guidance}")
 
         if img.ndim != 3 or txt.ndim != 3:
@@ -687,6 +706,7 @@
         txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64)  # (batch_size, )
         txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len)
         max_txt_len = torch.max(txt_seq_len).item()  # max text length in the batch
+        # print(f"max_txt_len: {max_txt_len}, txt_seq_len: {txt_seq_len}")
 
         # trim txt embedding to the text length
         txt = txt[:, :max_txt_len, :]
@@ -700,23 +720,27 @@
                 self.offloader_double.wait_for_block(i)
 
             # the guidance replaced by FFN output
-            img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
-            txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
+            img_mod = mod_vectors_dict.pop(f"double_blocks.{i}.img_mod.lin")
+            txt_mod = mod_vectors_dict.pop(f"double_blocks.{i}.txt_mod.lin")
             double_mod = [img_mod, txt_mod]
+            del img_mod, txt_mod
 
             img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len)
+            del double_mod
 
             if self.blocks_to_swap:
                 self.offloader_double.submit_move_blocks(self.double_blocks, i)
 
         img = torch.cat((img, txt), 1)
+        del txt
 
         for i, block in enumerate(self.single_blocks):
            if self.blocks_to_swap:
                self.offloader_single.wait_for_block(i)
 
-            single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
+            single_mod = mod_vectors_dict.pop(f"single_blocks.{i}.modulation.lin")
             img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len)
+            del single_mod
 
             if self.blocks_to_swap:
                 self.offloader_single.submit_move_blocks(self.single_blocks, i)
diff --git a/library/flux_models.py b/library/flux_models.py
index 6f889755..2a2fe5f8 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -1009,6 +1009,9 @@ class Flux(nn.Module):
         self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
         self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
 
+    def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
+        return None  # FLUX.1 does not use input_vec, but Chroma does.
+
     def forward(
         self,
         img: Tensor,
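
Note on the timestep scale: timestep_embedding in the FLUX code base multiplies its input by an internal time_factor of 1000 before computing the sinusoidal embedding, so it expects timesteps normalized to [0, 1]. The trainer's timesteps run from 0 to 1000, so passing them through unscaled double-applies the factor and feeds the Chroma approximator embeddings at the wrong frequencies; dividing by 1000 at the call site restores the expected range. A minimal sketch of the embedding, assuming the BFL-style implementation used here:

    import math
    import torch
    from torch import Tensor

    def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor:
        # t is expected in [0, 1]; time_factor supplies the usual 0..1000 scale internally
        t = time_factor * t
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(t.device)
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:  # pad odd dims with a zero column
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb.to(t) if torch.is_floating_point(t) else emb

    t_raw = torch.tensor([500.0])               # scheduler timestep in [0, 1000]
    emb = timestep_embedding(t_raw / 1000, 64)  # normalized, as the patched call site does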
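The flux_models.py hunk is what lets the trainer drop its model_type branch: get_input_vec is defined on the Flux base class to return None, Chroma overrides it, and the call site becomes unconditional. A schematic of that dispatch, reusing the timestep_embedding sketch above; the class and helper bodies are illustrative, not the actual library code:

    import torch
    from torch import Tensor

    class FluxSketch:
        # FLUX.1 has no distilled-guidance approximator, so there is no input_vec
        def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor | None:
            return None

    class ChromaSketch(FluxSketch):
        # Chroma builds the approximator input from timestep and guidance embeddings
        def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor | None:
            distill_timestep = timestep_embedding(timesteps, 16)  # dims illustrative
            distil_guidance = timestep_embedding(guidance, 16)
            # simplified: the real vector also carries per-block modulation indices
            return torch.cat([distill_timestep, distil_guidance], dim=-1)

    # the trainer no longer cares which model it holds:
    for unet in (FluxSketch(), ChromaSketch()):
        input_vec = unet.get_input_vec(timesteps=torch.tensor([0.5]), guidance=torch.tensor([1.0]))
        print(type(input_vec))  # NoneType for FLUX.1, Tensor for Chroma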
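The scattered del statements and the switch from dict indexing to mod_vectors_dict.pop() do not change results; they drop the last Python reference to large intermediates (QKV projections, per-sample attention buffers, consumed modulation vectors) as soon as they are used, so PyTorch's caching allocator can reuse that VRAM within the same forward pass instead of holding everything until the function returns. That lowers peak memory, which matters most with block swapping, where the per-block footprint is the budget. The pattern in miniature (illustrative only):

    import torch

    @torch.no_grad()
    def block_step(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
        h = x @ w1              # large intermediate activation
        y = torch.relu(h) @ w2
        del h                   # h is dead past this point; without del it lives until return
        return y

One caveat: del only removes the name binding. The memory is released only once nothing else references the tensor, and during training the autograd graph can keep an intermediate alive regardless, so the savings are largest in no-grad paths and when the blocks run under gradient checkpointing.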