Merge pull request #1812 from rockerBOO/tests

Add pytest testing
This commit is contained in:
Kohya S.
2024-12-02 21:38:43 +09:00
committed by GitHub
5 changed files with 210 additions and 5 deletions

View File

@@ -21,7 +21,7 @@ from typing import (
Optional,
Sequence,
Tuple,
Union,
Union
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import glob
@@ -4607,7 +4607,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
accelerator.load_state(dirname)
def get_optimizer(args, trainable_params):
def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
optimizer_type = args.optimizer_type