torch 2.4

This commit is contained in:
Akegarasu
2024-10-10 14:08:55 +08:00
parent 3de42b6edb
commit 9f4dac5731

View File

@@ -33,6 +33,7 @@ from io import BytesIO
 import toml
 from tqdm import tqdm
+from packaging.version import Version
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
@@ -5048,7 +5049,7 @@ def prepare_accelerator(args: argparse.Namespace):
     kwargs_handlers = [
         InitProcessGroupKwargs(
             backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
-            init_method="env://?use_libuv=False" if os.name == "nt" else None,
+            init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
             timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
         ) if torch.cuda.device_count() > 1 else None,
         DistributedDataParallelKwargs(