fix: inference for Chroma model

This commit is contained in:
Kohya S
2025-07-20 14:08:54 +09:00
parent 24d2ea86c7
commit 404ddb060d
3 changed files with 23 additions and 18 deletions

View File

@@ -240,7 +240,7 @@ class DoubleStreamBlock(nn.Module):
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe, mask=mask)
attn = attention(q, k, v, pe=pe, attn_mask=mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
@@ -343,7 +343,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=mask)
attn = attention(q, k, v, pe=pe, attn_mask=mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
# replaced with compiled fn
@@ -555,6 +555,11 @@ class Chroma(Flux):
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
# print(
# f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
# )
# print(f"timesteps: {timesteps}, guidance: {guidance}")
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")

View File

@@ -146,7 +146,7 @@ def load_flow_model(
from . import chroma_models
# build model
logger.info("Building Chroma model from BFL checkpoint")
logger.info("Building Chroma model")
with torch.device("meta"):
model = chroma_models.Chroma(chroma_models.chroma_params)
if dtype is not None: