Rename DDP options with a ddp_ prefix; pass DDP kwargs only when needed (ref #1000)

This commit is contained in:
Kohya S
2023-12-13 21:02:26 +09:00
parent 471d274803
commit d309a27a51

View File

@@ -2900,10 +2900,14 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト分、Noneでaccelerateのデフォルト",
)
parser.add_argument(
"--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP",
"--ddp_gradient_as_bucket_view",
action="store_true",
help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
)
parser.add_argument(
"--static_graph", action="store_true", help="enable static_graph for DDP",
"--ddp_static_graph",
action="store_true",
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
)
parser.add_argument(
"--clip_skip",
@@ -3866,10 +3870,11 @@ def prepare_accelerator(args: argparse.Namespace):
if args.wandb_api_key is not None:
wandb.login(key=args.wandb_api_key)
kwargs_handlers = (
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
DistributedDataParallelKwargs(gradient_as_bucket_view=args.gradient_as_bucket_view, static_graph=args.static_graph)
DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None,
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
accelerator = Accelerator(