mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
feat: fix timestep for input_vec for Chroma
This commit is contained in:
@@ -341,9 +341,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
||||
|
||||
# get modulation vectors for Chroma
|
||||
input_vec = None
|
||||
if self.model_type == "chroma":
|
||||
input_vec = unet.get_input_vec(timesteps=timesteps, guidance=guidance_vec, batch_size=bsz)
|
||||
input_vec = unet.get_input_vec(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
noisy_model_input.requires_grad_(True)
|
||||
|
||||
@@ -223,7 +223,10 @@ class DoubleStreamBlock(nn.Module):
|
||||
# img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
||||
img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
del img_modulated
|
||||
|
||||
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
del img_qkv
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
@@ -232,7 +235,10 @@ class DoubleStreamBlock(nn.Module):
|
||||
# txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
||||
txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
del txt_modulated
|
||||
|
||||
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention: we split the batch into each element
|
||||
@@ -263,9 +269,11 @@ class DoubleStreamBlock(nn.Module):
|
||||
img_v[i] = None
|
||||
|
||||
attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None) # attn = (1, L, D)
|
||||
del q, k, v
|
||||
img_attn_i = attn[:, :img_len, :]
|
||||
txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device)
|
||||
txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :]
|
||||
del attn
|
||||
txt_attn.append(txt_attn_i)
|
||||
img_attn.append(img_attn_i)
|
||||
|
||||
@@ -279,27 +287,31 @@ class DoubleStreamBlock(nn.Module):
|
||||
# attn = attention(q, k, v, pe=pe, attn_mask=mask)
|
||||
# txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
# calculate the img blocks
|
||||
# replaced with compiled fn
|
||||
# img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
# img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
||||
img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
del img_attn, img_mod1
|
||||
img = self.modulation_gate_fn(
|
||||
img,
|
||||
img_mod2.gate,
|
||||
self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)),
|
||||
)
|
||||
del img_mod2
|
||||
|
||||
# calculate the txt bloks
|
||||
# calculate the txt blocks
|
||||
# replaced with compiled fn
|
||||
# txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
# txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
del txt_attn, txt_mod1
|
||||
txt = self.modulation_gate_fn(
|
||||
txt,
|
||||
txt_mod2.gate,
|
||||
self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)),
|
||||
)
|
||||
del txt_mod2
|
||||
|
||||
return img, txt
|
||||
|
||||
@@ -374,8 +386,10 @@ class SingleStreamBlock(nn.Module):
|
||||
# x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
||||
x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift)
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
del x_mod
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# # compute attention
|
||||
@@ -399,12 +413,15 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device)
|
||||
attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed
|
||||
del attn_trimmed
|
||||
attn.append(attn_i)
|
||||
|
||||
attn = torch.cat(attn, dim=0)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
del attn, mlp
|
||||
# replaced with compiled fn
|
||||
# return x + mod.gate * output
|
||||
return self.modulation_gate_fn(x, mod.gate, output)
|
||||
@@ -625,6 +642,7 @@ class Chroma(Flux):
|
||||
print("Chroma: Gradient checkpointing disabled.")
|
||||
|
||||
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
# print(f"Chroma get_input_vec: timesteps {timesteps}, guidance: {guidance}, batch_size: {batch_size}")
|
||||
distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
|
||||
# TODO: need to add toggle to omit this from schnell but that's not a priority
|
||||
distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
|
||||
@@ -656,6 +674,7 @@ class Chroma(Flux):
|
||||
# print(
|
||||
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
|
||||
# )
|
||||
# print(f"input_vec shape: {input_vec.shape if input_vec is not None else 'None'}")
|
||||
# print(f"timesteps: {timesteps}, guidance: {guidance}")
|
||||
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@@ -687,6 +706,7 @@ class Chroma(Flux):
|
||||
txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64) # (batch_size, )
|
||||
txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len)
|
||||
max_txt_len = torch.max(txt_seq_len).item() # max text length in the batch
|
||||
# print(f"max_txt_len: {max_txt_len}, txt_seq_len: {txt_seq_len}")
|
||||
|
||||
# trim txt embedding to the text length
|
||||
txt = txt[:, :max_txt_len, :]
|
||||
@@ -700,23 +720,27 @@ class Chroma(Flux):
|
||||
self.offloader_double.wait_for_block(i)
|
||||
|
||||
# the guidance replaced by FFN output
|
||||
img_mod = mod_vectors_dict[f"double_blocks.{i}.img_mod.lin"]
|
||||
txt_mod = mod_vectors_dict[f"double_blocks.{i}.txt_mod.lin"]
|
||||
img_mod = mod_vectors_dict.pop(f"double_blocks.{i}.img_mod.lin")
|
||||
txt_mod = mod_vectors_dict.pop(f"double_blocks.{i}.txt_mod.lin")
|
||||
double_mod = [img_mod, txt_mod]
|
||||
del img_mod, txt_mod
|
||||
|
||||
img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len)
|
||||
del double_mod
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_double.submit_move_blocks(self.double_blocks, i)
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
del txt
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_single.wait_for_block(i)
|
||||
|
||||
single_mod = mod_vectors_dict[f"single_blocks.{i}.modulation.lin"]
|
||||
single_mod = mod_vectors_dict.pop(f"single_blocks.{i}.modulation.lin")
|
||||
img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len)
|
||||
del single_mod
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.offloader_single.submit_move_blocks(self.single_blocks, i)
|
||||
|
||||
@@ -1009,6 +1009,9 @@ class Flux(nn.Module):
|
||||
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
|
||||
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
||||
|
||||
def get_input_vec(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
|
||||
return None # FLUX.1 does not use input_vec, but Chroma does.
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: Tensor,
|
||||
|
||||
Reference in New Issue
Block a user