mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
add --activation_memory_budget to training arguments
remove SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET env
This commit is contained in:
@@ -35,11 +35,6 @@ from library import custom_offloading_utils
|
||||
disable_selective_torch_compile = (
|
||||
os.getenv("SDSCRIPTS_SELECTIVE_TORCH_COMPILE", "0") == "0"
|
||||
)
|
||||
memory_budget = float(
|
||||
os.getenv("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET", "0")
|
||||
)
|
||||
if memory_budget > 0:
|
||||
torch._functorch.config.activation_memory_budget = memory_budget
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func
|
||||
|
||||
@@ -3974,6 +3974,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--activation_memory_budget",
|
||||
type=float,
|
||||
default=None,
|
||||
help="activation memory budget setting for torch.compile (range: 0~1). Smaller value saves more memory at cost of speed. If set, use --torch_compile without --gradient_checkpointing is recommended. Requires PyTorch 2.4. / torch.compileのactivation memory budget設定(0~1の値)。この値を小さくするとメモリ使用量を節約できますが、処理速度は低下します。この設定を行う場合は、--gradient_checkpointing オプションを指定せずに --torch_compile を使用することをお勧めします。PyTorch 2.4以降が必要です。"
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument(
|
||||
"--sdpa",
|
||||
@@ -5506,6 +5512,12 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
if args.torch_compile:
|
||||
dynamo_backend = args.dynamo_backend
|
||||
|
||||
if args.activation_memory_budget:
|
||||
logger.info(
|
||||
f"set torch compile activation memory budget to {args.activation_memory_budget}"
|
||||
)
|
||||
torch._functorch.config.activation_memory_budget = args.activation_memory_budget # type: ignore
|
||||
|
||||
kwargs_handlers = [
|
||||
(
|
||||
InitProcessGroupKwargs(
|
||||
|
||||
Reference in New Issue
Block a user