fixed FeedForward

This commit is contained in:
urlesistiana
2025-10-01 18:00:16 +08:00
parent 3420a6f7d1
commit 3bbfa9b258

View File

@@ -637,7 +637,7 @@ class FeedForward(nn.Module):
@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x):
return self.w2(F.silu(self.w1(x)*self.w3(x)))
return self.w2(F.silu(self.w1(x))*self.w3(x))
class JointTransformerBlock(GradientCheckpointMixin):