Merge pull request #1686 from Akegarasu/sd3

fix: fix some distributed training error in windows
Author: Kohya S.
Date: 2024-10-12 14:33:02 +09:00
Committed by: GitHub

@@ -34,6 +34,7 @@ import toml
 
 # from concurrent.futures import ThreadPoolExecutor, as_completed
 from tqdm import tqdm
+from packaging.version import Version
 
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
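
The new packaging.version import exists so the hunk below can compare torch.__version__ semantically; a plain string comparison misorders releases such as 2.10 vs 2.4. A minimal standalone sketch of the behavior (illustrative, not part of the commit):

    from packaging.version import Version

    # Release segments compare numerically, unlike raw strings.
    assert Version("2.10.0") > Version("2.4.0")    # numeric: 10 > 4
    assert not ("2.10.0" > "2.4.0")                # lexicographic: "1" < "4"
    # PEP 440 local version labels (e.g. CUDA-tagged torch builds) sort
    # after the bare release, so the >= check still passes for them.
    assert Version("2.4.0+cu121") >= Version("2.4.0")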
@@ -5077,17 +5078,18 @@ def prepare_accelerator(args: argparse.Namespace):
     if args.torch_compile:
         dynamo_backend = args.dynamo_backend
 
-    kwargs_handlers = (
-        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
-        (
-            DistributedDataParallelKwargs(
-                gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
-            )
-            if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
-            else None
-        ),
-    )
-    kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
+    kwargs_handlers = [
+        InitProcessGroupKwargs(
+            backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
+            init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
+            timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
+        ) if torch.cuda.device_count() > 1 else None,
+        DistributedDataParallelKwargs(
+            gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
+            static_graph=args.ddp_static_graph
+        ) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
+    ]
+    kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
 
     accelerator = Accelerator(
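
Why the change works: NCCL is not available on Windows (os.name == "nt") or on CPU-only machines, so the process group must fall back to gloo, and PyTorch 2.4 switched TCPStore to a libuv backend that Windows builds lack, which the init_method query string opts out of. A standalone sketch of the platform checks, using hypothetical helper names pick_backend/pick_init_method (they are not in the repository):

    import os
    from typing import Optional

    import torch
    from packaging.version import Version

    def pick_backend() -> str:
        # NCCL requires Linux + NVIDIA GPUs; Windows ("nt") and CPU-only
        # machines must fall back to gloo.
        return "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl"

    def pick_init_method() -> Optional[str]:
        # PyTorch >= 2.4 defaults TCPStore to libuv, which Windows wheels
        # are built without, so it is disabled via the query string there.
        if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0"):
            return "env://?use_libuv=False"
        return None  # keep accelerate's default init method elsewhere

Recent accelerate versions expose backend, init_method, and timeout on InitProcessGroupKwargs, which is why the fix can thread these settings through without calling torch.distributed.init_process_group directly; the torch.cuda.device_count() > 1 guard also keeps the handler out of single-GPU runs.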