mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix: distributed training in windows
This commit is contained in:
@@ -5045,17 +5045,18 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
if args.torch_compile:
|
if args.torch_compile:
|
||||||
dynamo_backend = args.dynamo_backend
|
dynamo_backend = args.dynamo_backend
|
||||||
|
|
||||||
kwargs_handlers = (
|
kwargs_handlers = [
|
||||||
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
|
InitProcessGroupKwargs(
|
||||||
(
|
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||||
DistributedDataParallelKwargs(
|
init_method="env://?use_libuv=False" if os.name == "nt" else None,
|
||||||
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph
|
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
|
||||||
)
|
) if torch.cuda.device_count() > 1 else None,
|
||||||
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
|
DistributedDataParallelKwargs(
|
||||||
else None
|
gradient_as_bucket_view=args.ddp_gradient_as_bucket_view,
|
||||||
),
|
static_graph=args.ddp_static_graph
|
||||||
)
|
) if args.ddp_gradient_as_bucket_view or args.ddp_static_graph else None
|
||||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
]
|
||||||
|
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
|
||||||
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
||||||
|
|
||||||
accelerator = Accelerator(
|
accelerator = Accelerator(
|
||||||
|
|||||||
Reference in New Issue
Block a user