mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
make torch.compile happy
don't compile funcs with complex ops simplify FeedForward to avoid "cache line invalidated" error
This commit is contained in:
@@ -558,7 +558,7 @@ class JointAttention(nn.Module):
|
||||
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
|
||||
)
|
||||
|
||||
@torch.compiler.disable
|
||||
@torch.compiler.disable(reason="complex ops inside")
|
||||
def apply_rope(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
@@ -634,14 +634,10 @@ class FeedForward(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
nn.init.xavier_uniform_(self.w3.weight)
|
||||
|
||||
# @torch.compile
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
@torch.compile(disable=disable_selective_torch_compile)
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
return self.w2(F.silu(self.w1(x)*self.w3(x)))
|
||||
|
||||
|
||||
class JointTransformerBlock(GradientCheckpointMixin):
|
||||
@@ -820,6 +816,7 @@ class RopeEmbedder:
|
||||
self.axes_lens = axes_lens
|
||||
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
|
||||
@torch.compiler.disable(reason="complex ops inside")
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
device = ids.device
|
||||
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
|
||||
@@ -1232,6 +1229,7 @@ class NextDiT(nn.Module):
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.compiler.disable(reason="complex ops inside")
|
||||
def precompute_freqs_cis(
|
||||
dim: List[int],
|
||||
end: List[int],
|
||||
|
||||
Reference in New Issue
Block a user