feat: support torch.compile

commit 62e7516537 (parent 20296b4f0e)
Author: Plat
Date: 2023-12-27 02:13:37 +09:00


@@ -2848,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true", action="store_true",
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う", help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
) )
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
parser.add_argument(
"--dynamo_backend",
type=str,
default="inductor",
# available backends:
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
# https://pytorch.org/docs/stable/torch.compiler.html
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
help="dynamo backend type (default is inductor) / dynamoのbackendの種類デフォルトは inductor"
)
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う") parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
parser.add_argument( parser.add_argument(
"--sdpa", "--sdpa",
@@ -3870,6 +3881,11 @@ def prepare_accelerator(args: argparse.Namespace):
        if args.wandb_api_key is not None:
            wandb.login(key=args.wandb_api_key)
    # torch.compile option: if "NO", torch.compile is not used
    dynamo_backend = "NO"
    if args.torch_compile:
        dynamo_backend = args.dynamo_backend
    kwargs_handlers = (
        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
        DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
@@ -3883,6 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace):
        log_with=log_with,
        project_dir=logging_dir,
        kwargs_handlers=kwargs_handlers,
        dynamo_backend=dynamo_backend,
    )
    return accelerator
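A minimal usage sketch of the Accelerate side, assuming its public API rather than anything in this diff: Accelerator accepts dynamo_backend as a string, and the value "NO" disables compilation, which is why the code above defaults to "NO" and only switches to args.dynamo_backend when --torch_compile is set.

    import torch
    from accelerate import Accelerator

    # "NO" leaves the model uncompiled; any other supported backend name
    # makes accelerator.prepare() wrap the model with torch.compile.
    accelerator = Accelerator(dynamo_backend="inductor")
    model = accelerator.prepare(torch.nn.Linear(8, 8))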