mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
fix DDP SDXL training
This commit is contained in:
@@ -19,7 +19,7 @@ from typing import (
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||
import gc
|
||||
import glob
|
||||
import math
|
||||
@@ -2878,6 +2878,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
default=None,
|
||||
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト(分、Noneでaccelerateのデフォルト)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_as_bucket_view", action="store_true", help="enable gradient_as_bucket_view for DDP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--static_graph", action="store_true", help="enable static_graph for DDP",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_skip",
|
||||
type=int,
|
||||
@@ -3832,9 +3838,12 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
if args.wandb_api_key is not None:
|
||||
wandb.login(key=args.wandb_api_key)
|
||||
|
||||
|
||||
kwargs_handlers = (
|
||||
None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
|
||||
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
|
||||
DistributedDataParallelKwargs(gradient_as_bucket_view=args.gradient_as_bucket_view, static_graph=args.static_graph)
|
||||
)
|
||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
|
||||
Reference in New Issue
Block a user