feat: fix timestep for input_vec for Chroma

This commit is contained in:
Kohya S
2025-07-20 20:53:06 +09:00
parent b4e862626a
commit 0b763ef1f1
3 changed files with 34 additions and 9 deletions

View File

@@ -341,9 +341,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
# get modulation vectors for Chroma
input_vec = None
if self.model_type == "chroma":
input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz)
input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)

View File

@@ -223,7 +223,10 @@ class DoubleStreamBlock(nn.Module):
# img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift)
img_qkv = self.img_attn.qkv(img_modulated)
del img_modulated
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
@@ -232,7 +235,10 @@ class DoubleStreamBlock(nn.Module):
# txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift)
txt_qkv = self.txt_attn.qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention: we split the batch into each element
@@ -263,9 +269,11 @@ class DoubleStreamBlock(nn.Module):
img_v[i] = None
attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None) # attn = (1, L, D)
del q, k, v
img_attn_i = attn[:, :img_len, :]
txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device)
txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :]
del attn
txt_attn.append(txt_attn_i)
img_attn.append(img_attn_i)
@@ -279,27 +287,31 @@ class DoubleStreamBlock(nn.Module):
# attn = attention(q, k, v, pe=pe, attn_mask=mask)
# txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
# calculate the img blocks
# replaced with compiled fn
# img = img + img_mod1.gate * self.img_attn.proj(img_attn)
# img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn))
del img_attn, img_mod1
img = self.modulation_gate_fn(
img,
img_mod2.gate,
self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)),
)
del img_mod2
# calculate the txt bloks
# calculate the txt blocks
# replaced with compiled fn
# txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
# txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn))
del txt_attn, txt_mod1
txt = self.modulation_gate_fn(
txt,
txt_mod2.gate,
self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)),
)
del txt_mod2
return img, txt
@@ -374,8 +386,10 @@ class SingleStreamBlock(nn.Module):
# x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift)
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
del x_mod
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
del qkv
q, k = self.norm(q, k, v)
# # compute attention
@@ -399,12 +413,15 @@ class SingleStreamBlock(nn.Module):
attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device)
attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed
del attn_trimmed
attn.append(attn_i)
attn = torch.cat(attn, dim=0)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
del attn, mlp
# replaced with compiled fn
# return x + mod.gate * output
return self.modulation_gate_fn(x, mod.gate, output)
@@ -625,6 +642,7 @@ class Chroma(Flux):
print("Chroma: Gradient checkpointing disabled.")
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
# print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}")
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
# TODO: need to add toggle to omit this from schnell but that's not a priority
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
@@ -656,6 +674,7 @@ class Chroma(Flux):
# print(
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
# )
# print(f"input_vec shape: {input_vec.shape if input_vec is not None else 'None'}")
# print(f"timesteps: {timesteps}, guidance: {guidance}")
if img.ndim != 3 or txt.ndim != 3:
@@ -687,6 +706,7 @@ class Chroma(Flux):
txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, )
txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len)
max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch
# print(f"max_txt_len: {max_txt_len}, txt_seq_len: {txt_seq_len}")
# trim txt embedding to the text length
txt = txt[:, :max_txt_len, :]
@@ -700,23 +720,27 @@ class Chroma(Flux):
self.offloader_double.wait_for_block(i)
# the guidance replaced by FFN output
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
img_mod = mod_vectors_dict.pop(f"double_blocks.{i}.img_mod.lin")
txt_mod = mod_vectors_dict.pop(f"double_blocks.{i}.txt_mod.lin")
double_mod = [img_mod, txt_mod]
del img_mod, txt_mod
img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len)
del double_mod
if self.blocks_to_swap:
self.offloader_double.submit_move_blocks(self.double_blocks, i)
img = torch.cat((img, txt), 1)
del txt
for i, block in enumerate(self.single_blocks):
if self.blocks_to_swap:
self.offloader_single.wait_for_block(i)
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
single_mod = mod_vectors_dict.pop(f"single_blocks.{i}.modulation.lin")
img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len)
del single_mod
if self.blocks_to_swap:
self.offloader_single.submit_move_blocks(self.single_blocks, i)

View File

@@ -1009,6 +1009,9 @@ class Flux(nn.Module):
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
return None # FLUX.1 does not use input_vec, but Chroma does.
def forward(
self,
img: Tensor,