From 3de42b6edb151b172f483aec99fe380b1406a84a Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 10 Oct 2024 14:03:59 +0800 Subject: [PATCH 1/2] fix: distributed training in windows --- library/train_util.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e023f63a..3dabf9e2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5045,17 +5045,18 @@ def prepare_accelerator(args: argparse.Namespace): if args.torch_compile: dynamo_backend = args.dynamo_backend - kwargs_handlers = ( - InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None, - ( - DistributedDataParallelKwargs( - gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph - ) - if args.ddp_gradient_as_bucket_view or args.ddp_static_graph - else None - ), - ) - kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers)) + kwargs_handlers = [ + InitProcessGroupKwargs( + backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", + init_method="env://?use_libuv=False" if os.name == "nt" else None, + timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None + ) if torch.cuda.device_count() > 1 else None, + DistributedDataParallelKwargs( + gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, + static_graph=args.ddp_static_graph + ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None + ] + kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) accelerator = Accelerator( From 9f4dac5731fe2299c75b7671c6132febd57a4117 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 10 Oct 2024 14:08:55 +0800 Subject: [PATCH 2/2] torch 2.4 --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 3dabf9e2..2c20a924 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -33,6 +33,7 @@ from io import BytesIO import toml from tqdm import tqdm +from packaging.version import Version import torch from library.device_utils import init_ipex, clean_memory_on_device @@ -5048,7 +5049,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = [ InitProcessGroupKwargs( backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl", - init_method="env://?use_libuv=False" if os.name == "nt" else None, + init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None, timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None ) if torch.cuda.device_count() > 1 else None, DistributedDataParallelKwargs(