From c15e6b4f3bc45eefb541fca06ee2e575494235b0 Mon Sep 17 00:00:00 2001 From: urlesistiana <55231606+urlesistiana@users.noreply.github.com> Date: Tue, 30 Sep 2025 11:41:53 +0800 Subject: [PATCH 1/5] feat: add selective torch compile and activation memory budget to Lumina 2 --- library/lumina_models.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 7e925352..be60489e 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -20,6 +20,7 @@ # -------------------------------------------------------- import math +import os from typing import List, Optional, Tuple from dataclasses import dataclass @@ -31,6 +32,15 @@ import torch.nn.functional as F from library import custom_offloading_utils +disable_selective_torch_compile = ( + os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0" +) +memory_budget = float( + os.getenv("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET", "0") +) +if memory_budget > 0: + torch._functorch.config.activation_memory_budget = memory_budget + try: from flash_attn import flash_attn_varlen_func from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -553,7 +563,7 @@ class JointAttention(nn.Module): f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}" ) - +@torch.compiler.disable def apply_rope( x_in: torch.Tensor, freqs_cis: torch.Tensor, @@ -633,7 +643,8 @@ class FeedForward(nn.Module): # @torch.compile def _forward_silu_gating(self, x1, x3): return F.silu(x1) * x3 - + + @torch.compile(disable=disable_selective_torch_compile) def forward(self, x): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) @@ -701,6 +712,7 @@ class JointTransformerBlock(GradientCheckpointMixin): nn.init.zeros_(self.adaLN_modulation[1].weight) nn.init.zeros_(self.adaLN_modulation[1].bias) + @torch.compile(disable=disable_selective_torch_compile) def _forward( self, x: torch.Tensor, @@ -792,6 +804,7 @@ class FinalLayer(GradientCheckpointMixin): nn.init.zeros_(self.adaLN_modulation[1].weight) nn.init.zeros_(self.adaLN_modulation[1].bias) + @torch.compile(disable=disable_selective_torch_compile) def forward(self, x, c): scale = self.adaLN_modulation(c) x = modulate(self.norm_final(x), scale) From f25cb8abd1dc1cbd1a1eb844b58444cddce4079d Mon Sep 17 00:00:00 2001 From: urlesistiana <55231606+urlesistiana@users.noreply.github.com> Date: Tue, 30 Sep 2025 22:03:30 +0800 Subject: [PATCH 2/5] add --activation_memory_budget to training arguments remove SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET env --- library/lumina_models.py | 5 ----- library/train_util.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index be60489e..d12a9922 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -35,11 +35,6 @@ from library import custom_offloading_utils disable_selective_torch_compile = ( os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0" ) -memory_budget = float( - os.getenv("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET", "0") -) -if memory_budget > 0: - torch._functorch.config.activation_memory_budget = memory_budget try: from flash_attn import flash_attn_varlen_func diff --git a/library/train_util.py b/library/train_util.py index 756d88b1..8a54cd0c 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3974,6 +3974,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ], help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)", ) + parser.add_argument( + "--activation_memory_budget", + type=float, + default=None, + help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4. / torch.compileのactivation memory budget設定(0~1の値)。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。" + ) parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument( "--sdpa", @@ -5506,6 +5512,12 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend + if args.activation_memory_budget: + logger.info( + f"set torch compile activation memory budget to {args.activation_memory_budget}" + ) + torch._functorch.config.activation_memory_budget = args.activation_memory_budget # type: ignore + kwargs_handlers = [ ( InitProcessGroupKwargs( From 45cab086cc6e45566ad39ac99cc7be38ebb69f8e Mon Sep 17 00:00:00 2001 From: urlesistiana <55231606+urlesistiana@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:40:12 +0800 Subject: [PATCH 3/5] make torch.compile happy don't compile funcs with complex ops simplify FeedForward to avoid "cache line invalidated" error --- library/lumina_models.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index d12a9922..7881726e 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -558,7 +558,7 @@ class JointAttention(nn.Module): f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}" ) -@torch.compiler.disable +@torch.compiler.disable(reason="complex ops inside") def apply_rope( x_in: torch.Tensor, freqs_cis: torch.Tensor, @@ -634,14 +634,10 @@ class FeedForward(nn.Module): bias=False, ) nn.init.xavier_uniform_(self.w3.weight) - - # @torch.compile - def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 @torch.compile(disable=disable_selective_torch_compile) def forward(self, x): - return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + return self.w2(F.silu(self.w1(x)*self.w3(x))) class JointTransformerBlock(GradientCheckpointMixin): @@ -820,6 +816,7 @@ class RopeEmbedder: self.axes_lens = axes_lens self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta) + @torch.compiler.disable(reason="complex ops inside") def __call__(self, ids: torch.Tensor): device = ids.device self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis] @@ -1232,6 +1229,7 @@ class NextDiT(nn.Module): return output @staticmethod + @torch.compiler.disable(reason="complex ops inside") def precompute_freqs_cis( dim: List[int], end: List[int], From 3420a6f7d165082241db57aba2754c7c07f7cd80 Mon Sep 17 00:00:00 2001 From: urlesistiana <55231606+urlesistiana@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:55:24 +0800 Subject: [PATCH 4/5] check activation_memory_budget value range and accept value 0 --- library/train_util.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 8a54cd0c..67df2258 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5512,11 +5512,18 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend - if args.activation_memory_budget: - logger.info( - f"set torch compile activation memory budget to {args.activation_memory_budget}" - ) - torch._functorch.config.activation_memory_budget = args.activation_memory_budget # type: ignore + if args.activation_memory_budget is not None: # Note: 0 is a valid value. + if 0 <= args.activation_memory_budget <= 1: + logger.info( + f"set torch compile activation memory budget to {args.activation_memory_budget}" + ) + torch._functorch.config.activation_memory_budget = ( # type: ignore + args.activation_memory_budget + ) + else: + raise ValueError( + "activation_memory_budget must be between 0 and 1 (inclusive)" + ) kwargs_handlers = [ ( From 3bbfa9b258fc7e9a7ca0e8ce17376c298e3c1fb4 Mon Sep 17 00:00:00 2001 From: urlesistiana <55231606+urlesistiana@users.noreply.github.com> Date: Wed, 1 Oct 2025 18:00:16 +0800 Subject: [PATCH 5/5] fixed FeedForward --- library/lumina_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/lumina_models.py b/library/lumina_models.py index 7881726e..84fa44c5 100644 --- a/library/lumina_models.py +++ b/library/lumina_models.py @@ -637,7 +637,7 @@ class FeedForward(nn.Module): @torch.compile(disable=disable_selective_torch_compile) def forward(self, x): - return self.w2(F.silu(self.w1(x)*self.w3(x))) + return self.w2(F.silu(self.w1(x))*self.w3(x)) class JointTransformerBlock(GradientCheckpointMixin):