make torch.compile happy

don't compile funcs with complex ops

simplify FeedForward to avoid "cache line invalidated" error
This commit is contained in:
urlesistiana
2025-10-01 17:40:12 +08:00
parent f25cb8abd1
commit 45cab086cc

View File

@@ -558,7 +558,7 @@ class JointAttention(nn.Module):
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
)
@torch.compiler.disable
@torch.compiler.disable(reason="complex ops inside")
def apply_rope(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
@@ -634,14 +634,10 @@ class FeedForward(nn.Module):
bias=False,
)
nn.init.xavier_uniform_(self.w3.weight)
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
return self.w2(F.silu(self.w1(x)*self.w3(x)))
class JointTransformerBlock(GradientCheckpointMixin):
@@ -820,6 +816,7 @@ class RopeEmbedder:
self.axes_lens = axes_lens
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
@torch.compiler.disable(reason="complex ops inside")
def __call__(self, ids: torch.Tensor):
device = ids.device
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
@@ -1232,6 +1229,7 @@ class NextDiT(nn.Module):
return output
@staticmethod
@torch.compiler.disable(reason="complex ops inside")
def precompute_freqs_cis(
dim: List[int],
end: List[int],