diff --git a/library/lumina_models.py b/library/lumina_models.py
index 7e925352..be60489e 100644
--- a/library/lumina_models.py
+++ b/library/lumina_models.py
@@ -20,6 +20,7 @@
 # --------------------------------------------------------
 
 import math
+import os
 from typing import List, Optional, Tuple
 from dataclasses import dataclass
 
@@ -31,6 +32,15 @@ import torch.nn.functional as F
 
 from library import custom_offloading_utils
 
+disable_selective_torch_compile = (
+    os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
+)
+memory_budget = float(
+    os.getenv("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET", "0")
+)
+if memory_budget > 0:
+    torch._functorch.config.activation_memory_budget = memory_budget
+
 try:
     from flash_attn import flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -553,7 +563,7 @@ class JointAttention(nn.Module):
                 f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
             )
 
-
+@torch.compiler.disable
 def apply_rope(
     x_in: torch.Tensor,
     freqs_cis: torch.Tensor,
@@ -633,7 +643,8 @@ class FeedForward(nn.Module):
     # @torch.compile
     def _forward_silu_gating(self, x1, x3):
        return F.silu(x1) * x3
-
+
+    @torch.compile(disable=disable_selective_torch_compile)
     def forward(self, x):
         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
 
@@ -701,6 +712,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)
 
+    @torch.compile(disable=disable_selective_torch_compile)
     def _forward(
         self,
         x: torch.Tensor,
@@ -792,6 +804,7 @@ class FinalLayer(GradientCheckpointMixin):
         nn.init.zeros_(self.adaLN_modulation[1].weight)
         nn.init.zeros_(self.adaLN_modulation[1].bias)
 
+    @torch.compile(disable=disable_selective_torch_compile)
     def forward(self, x, c):
         scale = self.adaLN_modulation(c)
         x = modulate(self.norm_final(x), scale)
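
For reference, a minimal usage sketch (not part of the diff; the values below are illustrative): both environment variables are read once at module import time, so they must be set before library.lumina_models is imported, or exported in the shell that launches training.

    import os

    # Assumed usage based on the diff above: any value other than "0" enables the
    # selective torch.compile decorators ("0" or unset keeps them disabled).
    os.environ["SDSCRIPTS_SELECTIVE_TORCH_COMPILE"] = "1"

    # A positive value overrides torch._functorch.config.activation_memory_budget;
    # "0" (the default) leaves torch's own setting untouched.
    os.environ["SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET"] = "0.8"

    # Import only after the variables are set, since the module-level code reads them.
    from library import lumina_models  # noqa: E402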