Add pytest testing

This commit is contained in:
rockerBOO
2024-11-29 15:52:03 -05:00
parent 2a61fc0784
commit c7cadbc8c7
4 changed files with 216 additions and 2 deletions

View File

@@ -21,7 +21,7 @@ from typing import (
Optional,
Sequence,
Tuple,
Union,
Union
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import glob
@@ -4598,7 +4598,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
accelerator.load_state(dirname)
def get_optimizer(args, trainable_params):
def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
optimizer_type = args.optimizer_type