mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
support deepspeed
This commit is contained in:
@@ -20,6 +20,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
|
||||
from accelerate import DeepSpeedPlugin
|
||||
import gc
|
||||
import glob
|
||||
import math
|
||||
@@ -3124,6 +3125,47 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
|
||||
)
|
||||
|
||||
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
|
||||
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
|
||||
parser.add_argument(
|
||||
"--zero_stage",
|
||||
type=int, default=2,
|
||||
choices=[0, 1, 2, 3],
|
||||
help="Possible options are 0,1,2,3."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_optimizer",
|
||||
type=str, default=None,
|
||||
choices=[None, "cpu", "nvme"],
|
||||
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_optimizer_nvme_path",
|
||||
type=str, default=None,
|
||||
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_param_device",
|
||||
type=str, default=None,
|
||||
choices=[None, "cpu", "nvme"],
|
||||
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--offload_param_nvme_path",
|
||||
type=str, default=None,
|
||||
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero3_init_flag",
|
||||
action="store_true",
|
||||
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
|
||||
"Only applicable with ZeRO Stage-3."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--zero3_save_16bit_model",
|
||||
action="store_true",
|
||||
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
|
||||
)
|
||||
|
||||
def verify_training_args(args: argparse.Namespace):
|
||||
if args.v_parameterization and not args.v2:
|
||||
@@ -3912,6 +3954,17 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
else None,
|
||||
)
|
||||
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
|
||||
deepspeed_plugin = None
|
||||
if args.deepspeed:
|
||||
deepspeed_plugin = DeepSpeedPlugin(
|
||||
zero_stage=args.zero_stage,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
|
||||
offload_optimizer=args.offload_optimizer, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
|
||||
offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
|
||||
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
|
||||
)
|
||||
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -3919,6 +3972,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
project_dir=logging_dir,
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
deepspeed_plugin=deepspeed_plugin,
|
||||
)
|
||||
return accelerator
|
||||
|
||||
|
||||
Reference in New Issue
Block a user