From f25cb8abd1dc1cbd1a1eb844b58444cddce4079d Mon Sep 17 00:00:00 2001
From: urlesistiana <55231606+urlesistiana@users.noreply.github.com>
Date: Tue, 30 Sep 2025 22:03:30 +0800
Subject: [PATCH] add --activation_memory_budget to training arguments

remove SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET env
---
 library/lumina_models.py |  5 -----
 library/train_util.py    | 12 ++++++++++++
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/library/lumina_models.py b/library/lumina_models.py
index be60489e..d12a9922 100644
--- a/library/lumina_models.py
+++ b/library/lumina_models.py
@@ -35,11 +35,6 @@ from library import custom_offloading_utils
 disable_selective_torch_compile = (
     os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
 )
-memory_budget = float(
-    os.getenv("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET", "0")
-)
-if memory_budget > 0:
-    torch._functorch.config.activation_memory_budget = memory_budget
 
 try:
     from flash_attn import flash_attn_varlen_func
diff --git a/library/train_util.py b/library/train_util.py
index 756d88b1..8a54cd0c 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3974,6 +3974,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         ],
         help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
     )
+    parser.add_argument(
+        "--activation_memory_budget",
+        type=float,
+        default=None,
+        help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4 or later. / torch.compileのactivation memory budget設定(0~1の値)。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。"
+    )
     parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
     parser.add_argument(
         "--sdpa",
@@ -5506,6 +5512,12 @@ def prepare_accelerator(args: argparse.Namespace):
 
     if args.torch_compile:
         dynamo_backend = args.dynamo_backend
+        if args.activation_memory_budget:
+            logger.info(
+                f"set torch compile activation memory budget to {args.activation_memory_budget}"
+            )
+            torch._functorch.config.activation_memory_budget = args.activation_memory_budget  # type: ignore
+
     kwargs_handlers = [
         (
             InitProcessGroupKwargs(