mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
Accelerate dataloader_config to non_blocking if pin_memory is enabled
This commit is contained in:
@@ -23,7 +23,7 @@ from typing import (
|
||||
Tuple,
|
||||
Union
|
||||
)
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
@@ -5299,6 +5299,8 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
|
||||
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
|
||||
|
||||
dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -5307,6 +5309,7 @@ def prepare_accelerator(args: argparse.Namespace):
|
||||
kwargs_handlers=kwargs_handlers,
|
||||
dynamo_backend=dynamo_backend,
|
||||
deepspeed_plugin=deepspeed_plugin,
|
||||
dataloader_config=dataloader_config
|
||||
)
|
||||
print("accelerator device:", accelerator.device)
|
||||
return accelerator
|
||||
|
||||
Reference in New Issue
Block a user