Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 15:00:23 +00:00)
fix: inference for Chroma model
@@ -240,7 +240,7 @@ class DoubleStreamBlock(nn.Module):
         k = torch.cat((txt_k, img_k), dim=2)
         v = torch.cat((txt_v, img_v), dim=2)
 
-        attn = attention(q, k, v, pe=pe, mask=mask)
+        attn = attention(q, k, v, pe=pe, attn_mask=mask)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
@@ -343,7 +343,7 @@ class SingleStreamBlock(nn.Module):
         q, k = self.norm(q, k, v)
 
         # compute attention
-        attn = attention(q, k, v, pe=pe, mask=mask)
+        attn = attention(q, k, v, pe=pe, attn_mask=mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         # replaced with compiled fn
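
Both hunks make the same one-keyword fix: the attention helper evidently takes its mask as attn_mask, matching the keyword of torch.nn.functional.scaled_dot_product_attention, so the old mask= spelling would not have been accepted at these call sites. A minimal sketch of such a helper, assuming it follows the BFL Flux reference code (RoPE applied via pe, then SDPA); the exact signature and the apply_rope helper are assumptions, with apply_rope written in the reference formulation:

import torch.nn.functional as F
from torch import Tensor


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # Rotate query/key pairs by precomputed rotary embeddings (Flux-style RoPE).
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
    # SDPA's own keyword is attn_mask, hence the rename at the call sites above.
    q, k = apply_rope(q, k, pe)
    x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    return x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1)  # (B, H, L, D) -> (B, L, H*D)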
@@ -555,6 +555,11 @@ class Chroma(Flux):
         guidance: Tensor | None = None,
         txt_attention_mask: Tensor | None = None,
     ) -> Tensor:
+        # print(
+        #     f"Chroma forward: img shape {img.shape}, txt shape {txt.shape}, img_ids shape {img_ids.shape}, txt_ids shape {txt_ids.shape}"
+        # )
+        # print(f"timesteps: {timesteps}, guidance: {guidance}")
+
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
 
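
The added guard makes malformed inputs fail fast with a clear message instead of a shape error deeper in the network. A hedged illustration of shapes that pass it, assuming Flux-style packed latents (2x2 patches of a 16-channel latent, i.e. 64 features per token) and T5 text embeddings; all sizes here are illustrative:

import torch

bsz = 1
img = torch.randn(bsz, 1024, 64)   # (B, L_img, 64): 32x32 packed patches of a 64x64 latent
txt = torch.randn(bsz, 512, 4096)  # (B, L_txt, 4096): padded T5 token embeddings

# Exactly the condition the new ValueError enforces:
assert img.ndim == 3 and txt.ndim == 3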
@@ -146,7 +146,7 @@ def load_flow_model(
     from . import chroma_models
 
     # build model
-    logger.info("Building Chroma model from BFL checkpoint")
+    logger.info("Building Chroma model")
     with torch.device("meta"):
         model = chroma_models.Chroma(chroma_models.chroma_params)
     if dtype is not None:
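
Worth noting in this hunk's context: the model is built under torch.device("meta"), so construction allocates no real parameter storage until weights are loaded. A minimal sketch of that pattern with a stand-in module (nn.Linear here is hypothetical, in place of chroma_models.Chroma):

import torch
from torch import nn

with torch.device("meta"):         # parameters are created without real storage
    model = nn.Linear(4096, 4096)  # stand-in for chroma_models.Chroma(chroma_params)

# Materialize by assigning loaded tensors (e.g. from a safetensors checkpoint);
# assign=True (PyTorch >= 2.1) swaps the meta parameters for the loaded ones.
state_dict = {"weight": torch.randn(4096, 4096), "bias": torch.zeros(4096)}
model.load_state_dict(state_dict, assign=True)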