diff --git a/library/train_util.py b/library/train_util.py index 1e6fe3b8..9711dd56 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -23,7 +23,7 @@ from typing import ( Tuple, Union ) -from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState +from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration import glob import math import os @@ -5299,6 +5299,8 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers = [i for i in kwargs_handlers if i is not None] deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args) + dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -5307,6 +5309,7 @@ def prepare_accelerator(args: argparse.Namespace): kwargs_handlers=kwargs_handlers, dynamo_backend=dynamo_backend, deepspeed_plugin=deepspeed_plugin, + dataloader_config=dataloader_config ) print("accelerator device:", accelerator.device) return accelerator