mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
feat: support torch.compile
This commit is contained in:
@@ -2848,6 +2848,17 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
action="store_true",
|
||||
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う",
|
||||
)
|
||||
parser.add_argument("--torch_compile", action="store_true", help="use torch.compile (requires PyTorch 2.0) / torch.compile を使う")
|
||||
parser.add_argument(
|
||||
"--dynamo_backend",
|
||||
type=str,
|
||||
default="inductor",
|
||||
# available backends:
|
||||
# https://github.com/huggingface/accelerate/blob/d1abd59114ada8ba673e1214218cb2878c13b82d/src/accelerate/utils/dataclasses.py#L376-L388C5
|
||||
# https://pytorch.org/docs/stable/torch.compiler.html
|
||||
choices=["eager", "aot_eager", "inductor", "aot_ts_nvfuser", "nvprims_nvfuser", "cudagraphs", "ofi", "fx2trt", "onnxrt"],
|
||||
help="dynamo backend type (default is inductor) / dynamoのbackendの種類(デフォルトは inductor)"
|
||||
)
|
||||
parser.add_argument("--xformers", action="store_true", help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
|
||||
parser.add_argument(
|
||||
"--sdpa",
|
||||
@@ -3869,6 +3880,11 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
os.environ["WANDB_DIR"] = logging_dir
|
||||
if args.wandb_api_key is not None:
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
# torch.compile のオプション。 NO の場合は torch.compile は使わない
|
||||
dynamo_backend = "NO"
|
||||
if args.torch_compile:
|
||||
dynamo_backend = args.dynamo_backend
|
||||
|
||||
kwargs_handlers = (
|
||||
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
|
||||
@@ -3883,6 +3899,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
log_with=log_with,
|
||||
project_dir=logging_dir,
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
)
|
||||
return accelerator
|
||||
|
||||
|
||||
Reference in New Issue
Block a user