diff --git a/library/lumina_models.py b/library/lumina_models.py
index d12a9922..7881726e 100644
--- a/library/lumina_models.py
+++ b/library/lumina_models.py
@@ -558,7 +558,7 @@ class JointAttention(nn.Module):
                 f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
             )
 
-@torch.compiler.disable
+@torch.compiler.disable(reason="complex ops inside")
 def apply_rope(
     x_in: torch.Tensor,
     freqs_cis: torch.Tensor,
@@ -634,14 +634,10 @@ class FeedForward(nn.Module):
             bias=False,
         )
         nn.init.xavier_uniform_(self.w3.weight)
-
-    # @torch.compile
-    def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
 
     @torch.compile(disable=disable_selective_torch_compile)
     def forward(self, x):
-        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
 class JointTransformerBlock(GradientCheckpointMixin):
@@ -820,6 +816,7 @@ class RopeEmbedder:
         self.axes_lens = axes_lens
         self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
 
+    @torch.compiler.disable(reason="complex ops inside")
     def __call__(self, ids: torch.Tensor):
         device = ids.device
         self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
@@ -1232,6 +1229,7 @@ class NextDiT(nn.Module):
         return output
 
     @staticmethod
+    @torch.compiler.disable(reason="complex ops inside")
     def precompute_freqs_cis(
         dim: List[int],
         end: List[int],
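
For reviewers, here is a minimal, self-contained sketch of the two patterns this patch converges on: the SwiGLU feed-forward with the gating inlined into `forward` under selective `torch.compile`, and `torch.compiler.disable` on the complex-valued RoPE precompute. The class below and the single-axis `precompute_freqs_cis` are illustrative simplifications, not the signatures in `library/lumina_models.py`; only `w1`/`w2`/`w3`, the `disable_selective_torch_compile` flag, and the decorators come from the patch itself.

```python
# Minimal sketch of the patterns in this patch. Sizes, the standalone
# `disable_selective_torch_compile` flag, and the single-axis RoPE helper
# are illustrative simplifications, not the library's actual signatures.
import torch
import torch.nn as nn
import torch.nn.functional as F

disable_selective_torch_compile = False  # assumed stand-in for the module-level switch


class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value projection

    @torch.compile(disable=disable_selective_torch_compile)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU: SiLU-gated w1(x), multiplied elementwise by w3(x),
        # inlined so the whole forward lowers as one compiled graph.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


# Complex-tensor ops (torch.polar / torch.view_as_complex) are why the RoPE
# helpers are kept out of torch.compile. The `reason` kwarg needs a recent
# PyTorch; on older versions, fall back to a bare @torch.compiler.disable.
@torch.compiler.disable(reason="complex ops inside")
def precompute_freqs_cis(head_dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(end, dtype=torch.float32), freqs)
    return torch.polar(torch.ones_like(angles), angles)  # complex64 rotations


if __name__ == "__main__":
    ff = FeedForward(dim=16, hidden_dim=64)
    print(ff(torch.randn(2, 16)).shape)       # torch.Size([2, 16])
    print(precompute_freqs_cis(8, 32).shape)  # torch.Size([32, 4])
```

Keeping the gating inline means `forward` compiles as a single graph when selective compilation is enabled, while the `reason=` string documents, at the decorator site, why the complex-number helpers stay eager.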