This commit is contained in:
Urle Sistiana
2026-04-05 00:38:08 +00:00
committed by GitHub
2 changed files with 32 additions and 7 deletions

View File

@@ -20,6 +20,7 @@
# --------------------------------------------------------
import math
import os
from typing import List, Optional, Tuple
from dataclasses import dataclass
@@ -31,6 +32,10 @@ import torch.nn.functional as F
from library import custom_offloading_utils
disable_selective_torch_compile = (
os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
)
try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -549,7 +554,7 @@ class JointAttention(nn.Module):
f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
)
@torch.compiler.disable(reason="complex ops inside")
def apply_rope(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
@@ -625,13 +630,10 @@ class FeedForward(nn.Module):
bias=False,
)
nn.init.xavier_uniform_(self.w3.weight)
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
return self.w2(F.silu(self.w1(x))*self.w3(x))
class JointTransformerBlock(GradientCheckpointMixin):
@@ -697,6 +699,7 @@ class JointTransformerBlock(GradientCheckpointMixin):
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
@torch.compile(disable=disable_selective_torch_compile)
def _forward(
self,
x: torch.Tensor,
@@ -788,6 +791,7 @@ class FinalLayer(GradientCheckpointMixin):
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
@torch.compile(disable=disable_selective_torch_compile)
def forward(self, x, c):
scale = self.adaLN_modulation(c)
x = modulate(self.norm_final(x), scale)
@@ -808,6 +812,7 @@ class RopeEmbedder:
self.axes_lens = axes_lens
self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
@torch.compiler.disable(reason="complex ops inside")
def __call__(self, ids: torch.Tensor):
device = ids.device
self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
@@ -1219,6 +1224,7 @@ class NextDiT(nn.Module):
return output
@staticmethod
@torch.compiler.disable(reason="complex ops inside")
def precompute_freqs_cis(
dim: List[int],
end: List[int],

View File

@@ -3992,6 +3992,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor",
)
parser.add_argument(
"--activation_memory_budget",
type=float,
default=None,
help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4. / torch.compileのactivation memory budget設定01の値。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。"
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
"--sdpa",
@@ -5539,6 +5545,19 @@ def prepare_accelerator(args: argparse.Namespace):
if args.torch_compile:
dynamo_backend = args.dynamo_backend
if args.activation_memory_budget is not None: # Note: 0 is a valid value.
if 0 <= args.activation_memory_budget <= 1:
logger.info(
f"set torch compile activation memory budget to {args.activation_memory_budget}"
)
torch._functorch.config.activation_memory_budget = ( # type: ignore
args.activation_memory_budget
)
else:
raise ValueError(
"activation_memory_budget must be between 0 and 1 (inclusive)"
)
kwargs_handlers = [
(
InitProcessGroupKwargs(