Merge pull request #1000 from Isotr0py/dev

Fix multi-gpu SDXL training
Kohya S authored on 2023-12-13 20:52:11 +09:00, committed by GitHub
2 changed files with 14 additions and 2 deletions
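The fix routes accelerate's DistributedDataParallelKwargs into the Accelerator through its kwargs_handlers argument, alongside the existing InitProcessGroupKwargs timeout handler. Below is a minimal, self-contained sketch of that mechanism only; the timeout value and flag settings are illustrative placeholders rather than part of this commit, and static_graph assumes an accelerate release recent enough to expose it.

import datetime

from accelerate import Accelerator, DistributedDataParallelKwargs, InitProcessGroupKwargs

# Each handler forwards options to a specific stage: InitProcessGroupKwargs to
# torch.distributed process-group setup, DistributedDataParallelKwargs to the
# DistributedDataParallel wrapper used for multi-GPU training.
handlers = [
    InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=30)),  # placeholder timeout
    DistributedDataParallelKwargs(gradient_as_bucket_view=True, static_graph=True),
]

accelerator = Accelerator(kwargs_handlers=handlers)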


@@ -19,7 +19,7 @@ from typing import (
     Tuple,
     Union,
 )
-from accelerate import Accelerator, InitProcessGroupKwargs
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
 import gc
 import glob
 import math
@@ -2899,6 +2899,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         default=None,
         help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト分、Noneでaccelerateのデフォルト",
     )
+    parser.add_argument(
+        "--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP",
+    )
+    parser.add_argument(
+        "--static_graph", action="store_true", help="enable static_graph for DDP",
+    )
     parser.add_argument(
         "--clip_skip",
         type=int,
@@ -3860,9 +3866,12 @@ def prepare_accelerator(args: argparse.Namespace):
     if args.wandb_api_key is not None:
         wandb.login(key=args.wandb_api_key)
     kwargs_handlers = (
-        None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
+        InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
+        DistributedDataParallelKwargs(gradient_as_bucket_view=args.gradient_as_bucket_view, static_graph=args.static_graph)
     )
+    kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
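With this change the timeout entry is dropped by the filter when --ddp_timeout is unset, while the DistributedDataParallelKwargs entry is always passed, so the two new flags simply toggle its fields. At launch they are opt-in; a hypothetical multi-GPU invocation (script name, process count, and the remaining arguments are illustrative, not taken from this commit) might look like `accelerate launch --num_processes 2 sdxl_train.py --gradient_as_bucket_view --static_graph --ddp_timeout 30 ...`.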