Merge pull request #1812 from rockerBOO/tests

Add pytest testing
This commit is contained in:
Kohya S.
2024-12-02 21:38:43 +09:00
committed by GitHub
5 changed files with 210 additions and 5 deletions

42
.github/workflows/tests.yml vendored Normal file
View File

@@ -0,0 +1,42 @@
name: Python package

on: [push]

jobs:
  build:
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.10"]

    steps:
      # Single checkout/setup pair — the original workflow repeated these
      # steps twice, running checkout and setup-python back to back.
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          # Use the matrix value instead of a hard-coded '3.x', so the
          # matrix declaration actually controls the interpreter version.
          python-version: ${{ matrix.python-version }}
          cache: 'pip' # caching pip dependencies
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip setuptools wheel
          pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0
          pip install -r requirements.txt
      - name: Test with pytest
        run: |
          pip install pytest
          pytest

View File

@@ -1,9 +1,11 @@
--- ---
# yamllint disable rule:line-length
name: Typos name: Typos
on: # yamllint disable-line rule:truthy on:
push: push:
branches:
- main
- dev
pull_request: pull_request:
types: types:
- opened - opened
@@ -18,4 +20,4 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: typos-action - name: typos-action
uses: crate-ci/typos@v1.24.3 uses: crate-ci/typos@v1.28.1

View File

@@ -21,7 +21,7 @@ from typing import (
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
Union, Union
) )
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import glob import glob
@@ -4607,7 +4607,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
accelerator.load_state(dirname) accelerator.load_state(dirname)
def get_optimizer(args, trainable_params): def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
optimizer_type = args.optimizer_type optimizer_type = args.optimizer_type

8
pytest.ini Normal file
View File

@@ -0,0 +1,8 @@
[pytest]
# Require at least pytest 6.0 for this configuration format.
minversion = 6.0
# Collect tests only from the tests/ directory.
testpaths =
    tests
# Silence noisy warning categories emitted by third-party dependencies
# (torch, bitsandbytes, etc.) so test output stays readable.
filterwarnings =
    ignore::DeprecationWarning
    ignore::UserWarning
    ignore::FutureWarning

153
tests/test_optimizer.py Normal file
View File

@@ -0,0 +1,153 @@
from unittest.mock import patch
from library.train_util import get_optimizer
from train_network import setup_parser
import torch
from torch.nn import Parameter
# Optimizer libraries
import bitsandbytes as bnb
from lion_pytorch import lion_pytorch
import schedulefree
import dadaptation
import dadaptation.experimental as dadapt_experimental
import prodigyopt
import schedulefree as sf
import transformers
def test_default_get_optimizer():
    """Default CLI args (no --optimizer_type) must resolve to torch's AdamW."""
    with patch("sys.argv", [""]):
        args = setup_parser().parse_args()
        weight = Parameter(torch.tensor([1.5, 1.5]))
        opt_name, opt_args, opt = get_optimizer(args, [weight])
        assert opt_name == "torch.optim.adamw.AdamW"
        assert opt_args == ""
        assert isinstance(opt, torch.optim.AdamW)
def test_get_schedulefree_optimizer():
    """--optimizer_type AdamWScheduleFree must resolve to the schedulefree AdamW."""
    with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]):
        args = setup_parser().parse_args()
        weight = Parameter(torch.tensor([1.5, 1.5]))
        opt_name, opt_args, opt = get_optimizer(args, [weight])
        assert opt_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree"
        assert opt_args == ""
        assert isinstance(opt, schedulefree.adamw_schedulefree.AdamWScheduleFree)
def test_all_supported_optimizers():
    """Every supported --optimizer_type alias resolves to its expected class.

    Each table entry maps a CLI alias to the fully-qualified name that
    get_optimizer reports and the optimizer class it must instantiate.
    Assertions carry the alias in their failure message so a failing case
    is identifiable from the pytest report (bare asserts in a loop would
    otherwise hide which optimizer broke).
    """
    optimizers = [
        {
            "name": "bitsandbytes.optim.adamw.AdamW8bit",
            "alias": "AdamW8bit",
            "instance": bnb.optim.AdamW8bit,
        },
        {
            "name": "lion_pytorch.lion_pytorch.Lion",
            "alias": "Lion",
            "instance": lion_pytorch.Lion,
        },
        {
            "name": "torch.optim.adamw.AdamW",
            "alias": "AdamW",
            "instance": torch.optim.AdamW,
        },
        {
            "name": "bitsandbytes.optim.lion.Lion8bit",
            "alias": "Lion8bit",
            "instance": bnb.optim.Lion8bit,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW8bit",
            "alias": "PagedAdamW8bit",
            "instance": bnb.optim.PagedAdamW8bit,
        },
        {
            "name": "bitsandbytes.optim.lion.PagedLion8bit",
            "alias": "PagedLion8bit",
            "instance": bnb.optim.PagedLion8bit,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW",
            "alias": "PagedAdamW",
            "instance": bnb.optim.PagedAdamW,
        },
        {
            "name": "bitsandbytes.optim.adamw.PagedAdamW32bit",
            "alias": "PagedAdamW32bit",
            "instance": bnb.optim.PagedAdamW32bit,
        },
        {"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD},
        {
            "name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint",
            "alias": "DAdaptAdamPreprint",
            "instance": dadapt_experimental.DAdaptAdamPreprint,
        },
        {
            "name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad",
            "alias": "DAdaptAdaGrad",
            "instance": dadaptation.DAdaptAdaGrad,
        },
        {
            "name": "dadaptation.dadapt_adan.DAdaptAdan",
            "alias": "DAdaptAdan",
            "instance": dadaptation.DAdaptAdan,
        },
        {
            "name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP",
            "alias": "DAdaptAdanIP",
            "instance": dadapt_experimental.DAdaptAdanIP,
        },
        {
            "name": "dadaptation.dadapt_lion.DAdaptLion",
            "alias": "DAdaptLion",
            "instance": dadaptation.DAdaptLion,
        },
        {
            "name": "dadaptation.dadapt_sgd.DAdaptSGD",
            "alias": "DAdaptSGD",
            "instance": dadaptation.DAdaptSGD,
        },
        {
            "name": "prodigyopt.prodigy.Prodigy",
            "alias": "Prodigy",
            "instance": prodigyopt.Prodigy,
        },
        {
            "name": "transformers.optimization.Adafactor",
            "alias": "Adafactor",
            "instance": transformers.optimization.Adafactor,
        },
        {
            "name": "schedulefree.adamw_schedulefree.AdamWScheduleFree",
            "alias": "AdamWScheduleFree",
            "instance": sf.AdamWScheduleFree,
        },
        {
            "name": "schedulefree.sgd_schedulefree.SGDScheduleFree",
            "alias": "SGDScheduleFree",
            "instance": sf.SGDScheduleFree,
        },
    ]
    for opt in optimizers:
        alias = opt.get("alias")
        with patch("sys.argv", ["", "--optimizer_type", alias]):
            parser = setup_parser()
            args = parser.parse_args()
            params_t = torch.tensor([1.5, 1.5])
            param = Parameter(params_t)
            optimizer_name, _, optimizer = get_optimizer(args, [param])
            assert optimizer_name == opt.get("name"), (
                f"{alias}: expected name {opt.get('name')!r}, got {optimizer_name!r}"
            )
            instance = opt.get("instance")
            assert instance is not None, f"{alias}: table entry missing 'instance'"
            assert isinstance(optimizer, instance), (
                f"{alias}: expected {instance!r}, got {type(optimizer)!r}"
            )