add --activation_memory_budget to training arguments

remove SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET env
This commit is contained in:
urlesistiana
2025-09-30 22:03:30 +08:00
parent c15e6b4f3b
commit f25cb8abd1
2 changed files with 12 additions and 5 deletions

View File

@@ -35,11 +35,6 @@ from library import custom_offloading_utils
disable_selective_torch_compile = (
os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
)
memory_budget = float(
os.getenv("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET", "0")
)
if memory_budget > 0:
torch._functorch.config.activation_memory_budget = memory_budget
try:
from flash_attn import flash_attn_varlen_func

View File

@@ -3974,6 +3974,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
)
parser.add_argument(
"--activation_memory_budget",
type=float,
default=None,
help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4. / torch.compileのactivation memory budget設定(0~1の値)。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。"
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument(
"--sdpa",
@@ -5506,6 +5512,12 @@ def prepare_accelerator(args: argparse.Namespace):
if args.torch_compile:
dynamo_backend = args.dynamo_backend
if args.activation_memory_budget:
logger.info(
f"set torch compile activation memory budget to {args.activation_memory_budget}"
)
torch._functorch.config.activation_memory_budget = args.activation_memory_budget # type: ignore
kwargs_handlers = [
(
InitProcessGroupKwargs(