mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
torch 2.4
This commit is contained in:
@@ -33,6 +33,7 @@ from io import BytesIO
|
|||||||
import toml
|
import toml
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
@@ -5048,7 +5049,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
|||||||
kwargs_handlers = [
|
kwargs_handlers = [
|
||||||
InitProcessGroupKwargs(
|
InitProcessGroupKwargs(
|
||||||
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
backend = "gloo" if os.name == "nt" or not torch.cuda.is_available() else "nccl",
|
||||||
init_method="env://?use_libuv=False" if os.name == "nt" else None,
|
init_method="env://?use_libuv=False" if os.name == "nt" and Version(torch.__version__) >= Version("2.4.0") else None,
|
||||||
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
|
timeout=datetime.timedelta(minutes=args.ddp_timeout) if args.ddp_timeout else None
|
||||||
) if torch.cuda.device_count() > 1 else None,
|
) if torch.cuda.device_count() > 1 else None,
|
||||||
DistributedDataParallelKwargs(
|
DistributedDataParallelKwargs(
|
||||||
|
|||||||
Reference in New Issue
Block a user