support deepspeed

This commit is contained in:
BootsofLagrangian
2024-02-04 03:12:42 +09:00
parent cd19df49cd
commit dfe08f395f
5 changed files with 195 additions and 50 deletions

View File

@@ -20,6 +20,7 @@ from typing import (
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
from accelerate import DeepSpeedPlugin
import gc
import glob
import math
@@ -3124,6 +3125,47 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み"
)
# DeepSpeed Arguments. https://huggingface.co/docs/accelerate/usage_guides/deepspeed
parser.add_argument("--deepspeed", action="store_true", help="enable deepspeed training")
parser.add_argument(
"--zero_stage",
type=int, default=2,
choices=[0, 1, 2, 3],
help="Possible options are 0,1,2,3."
)
parser.add_argument(
"--offload_optimizer",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."
)
parser.add_argument(
"--offload_optimizer_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_device",
type=str, default=None,
choices=[None, "cpu", "nvme"],
help="Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--offload_param_nvme_path",
type=str, default=None,
help="Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."
)
parser.add_argument(
"--zero3_init_flag",
action="store_true",
help="Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
"Only applicable with ZeRO Stage-3."
)
parser.add_argument(
"--zero3_save_16bit_model",
action="store_true",
help="Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."
)
def verify_training_args(args: argparse.Namespace):
if args.v_parameterization and not args.v2:
@@ -3912,6 +3954,17 @@ def prepare_accelerator(args: argparse.Namespace):
else None,
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
deepspeed_plugin = None
if args.deepspeed:
deepspeed_plugin = DeepSpeedPlugin(
zero_stage=args.zero_stage,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_clipping=args.max_grad_norm,
offload_optimizer=args.offload_optimizer, offload_optimizer_nvme_path=args.offload_optimizer_nvme_path,
offload_param_device=args.offload_param_device, offload_param_nvme_path=args.offload_param_nvme_path,
zero3_init_flag=args.zero3_init_flag, zero3_save_16bit_model=args.zero3_save_16bit_model,
)
deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = args.train_batch_size
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -3919,6 +3972,7 @@ def prepare_accelerator(args: argparse.Namespace):
project_dir=logging_dir,
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
)
return accelerator