diff --git a/library/lumina_models.py b/library/lumina_models.py index 7881726e..84fa44c5 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -637,7 +637,7 @@ class FeedForward(nn.Module): @torch.compile(disable=disable_selective_torch_compile) def forward(self, x): - return self.w2(F.silu(self.w1(x)*self.w3(x))) + return self.w2(F.silu(self.w1(x))*self.w3(x)) class JointTransformerBlock(GradientCheckpointMixin):