From d309a27a5117ec088cc45e663be11036da3a6ba5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 13 Dec 2023 21:02:26 +0900 Subject: [PATCH] change option names, add ddp kwargs if needed ref #1000 --- library/train_util.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b60168f3..aed21f65 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2900,10 +2900,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)", ) parser.add_argument( - "--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP", + "--ddp_gradient_as_bucket_view", + action="store_true", + help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする", ) parser.add_argument( - "--static_graph", action="store_true", help="enable static_graph for DDP", + "--ddp_static_graph", + action="store_true", + help="enable static_graph for DDP / DDPでstatic_graphを有効にする", ) parser.add_argument( "--clip_skip", @@ -3866,10 +3870,11 @@ def prepare_accelerator(args: argparse.Namespace): if args.wandb_api_key is not None: wandb.login(key=args.wandb_api_key) - kwargs_handlers = ( InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - DistributedDataParallelKwargs(gradient_as_bucket_view=args.gradient_as_bucket_view, static_graph=args.static_graph) + DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph) + if args.ddp_gradient_as_bucket_view or args.ddp_static_graph + else None, ) kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) accelerator = Accelerator(