mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
Add pytest testing
This commit is contained in:
54
.github/workflows/tests.yml
vendored
Normal file
54
.github/workflows/tests.yml
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
|
||||
name: Python package
|
||||
|
||||
on: [push]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.10", "3.11"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
- name: Install dependencies
|
||||
run: python -m pip install --upgrade pip setuptools wheel
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.x'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pip install pytest pytest-cov
|
||||
pytest --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html
|
||||
|
||||
- name: Upload pytest test results
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pytest-results-${{ matrix.python-version }}
|
||||
path: junit/test-results-${{ matrix.python-version }}.xml
|
||||
# Use always() to always run this step to publish test results when there are test failures
|
||||
if: ${{ always() }}
|
||||
@@ -21,7 +21,7 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
Union
|
||||
)
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
||||
import glob
|
||||
@@ -4598,7 +4598,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||
accelerator.load_state(dirname)
|
||||
|
||||
|
||||
def get_optimizer(args, trainable_params):
|
||||
def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||
|
||||
optimizer_type = args.optimizer_type
|
||||
|
||||
7
pytest.ini
Normal file
7
pytest.ini
Normal file
@@ -0,0 +1,7 @@
|
||||
[pytest]
|
||||
minversion = 6.0
|
||||
testpaths =
|
||||
tests
|
||||
filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
153
tests/test_optimizer.py
Normal file
153
tests/test_optimizer.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from unittest.mock import patch
|
||||
from library.train_util import get_optimizer
|
||||
from train_network import setup_parser
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
# Optimizer libraries
|
||||
import bitsandbytes as bnb
|
||||
from lion_pytorch import lion_pytorch
|
||||
import schedulefree
|
||||
|
||||
import dadaptation
|
||||
import dadaptation.experimental as dadapt_experimental
|
||||
|
||||
import prodigyopt
|
||||
import schedulefree as sf
|
||||
import transformers
|
||||
|
||||
|
||||
def test_default_get_optimizer():
|
||||
with patch("sys.argv", [""]):
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
params_t = torch.tensor([1.5, 1.5])
|
||||
|
||||
param = Parameter(params_t)
|
||||
optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
|
||||
assert optimizer_name == "torch.optim.adamw.AdamW"
|
||||
assert optimizer_args == ""
|
||||
assert isinstance(optimizer, torch.optim.AdamW)
|
||||
|
||||
|
||||
def test_get_schedulefree_optimizer():
|
||||
with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]):
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
params_t = torch.tensor([1.5, 1.5])
|
||||
|
||||
param = Parameter(params_t)
|
||||
optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
|
||||
assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree"
|
||||
assert optimizer_args == ""
|
||||
assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree)
|
||||
|
||||
|
||||
def test_all_supported_optimizers():
|
||||
optimizers = [
|
||||
{
|
||||
"name": "bitsandbytes.optim.adamw.AdamW8bit",
|
||||
"alias": "AdamW8bit",
|
||||
"instance": bnb.optim.AdamW8bit,
|
||||
},
|
||||
{
|
||||
"name": "lion_pytorch.lion_pytorch.Lion",
|
||||
"alias": "Lion",
|
||||
"instance": lion_pytorch.Lion,
|
||||
},
|
||||
{
|
||||
"name": "torch.optim.adamw.AdamW",
|
||||
"alias": "AdamW",
|
||||
"instance": torch.optim.AdamW,
|
||||
},
|
||||
{
|
||||
"name": "bitsandbytes.optim.lion.Lion8bit",
|
||||
"alias": "Lion8bit",
|
||||
"instance": bnb.optim.Lion8bit,
|
||||
},
|
||||
{
|
||||
"name": "bitsandbytes.optim.adamw.PagedAdamW8bit",
|
||||
"alias": "PagedAdamW8bit",
|
||||
"instance": bnb.optim.PagedAdamW8bit,
|
||||
},
|
||||
{
|
||||
"name": "bitsandbytes.optim.lion.PagedLion8bit",
|
||||
"alias": "PagedLion8bit",
|
||||
"instance": bnb.optim.PagedLion8bit,
|
||||
},
|
||||
{
|
||||
"name": "bitsandbytes.optim.adamw.PagedAdamW",
|
||||
"alias": "PagedAdamW",
|
||||
"instance": bnb.optim.PagedAdamW,
|
||||
},
|
||||
{
|
||||
"name": "bitsandbytes.optim.adamw.PagedAdamW32bit",
|
||||
"alias": "PagedAdamW32bit",
|
||||
"instance": bnb.optim.PagedAdamW32bit,
|
||||
},
|
||||
{"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD},
|
||||
{
|
||||
"name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint",
|
||||
"alias": "DAdaptAdamPreprint",
|
||||
"instance": dadapt_experimental.DAdaptAdamPreprint,
|
||||
},
|
||||
{
|
||||
"name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad",
|
||||
"alias": "DAdaptAdaGrad",
|
||||
"instance": dadaptation.DAdaptAdaGrad,
|
||||
},
|
||||
{
|
||||
"name": "dadaptation.dadapt_adan.DAdaptAdan",
|
||||
"alias": "DAdaptAdan",
|
||||
"instance": dadaptation.DAdaptAdan,
|
||||
},
|
||||
{
|
||||
"name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP",
|
||||
"alias": "DAdaptAdanIP",
|
||||
"instance": dadapt_experimental.DAdaptAdanIP,
|
||||
},
|
||||
{
|
||||
"name": "dadaptation.dadapt_lion.DAdaptLion",
|
||||
"alias": "DAdaptLion",
|
||||
"instance": dadaptation.DAdaptLion,
|
||||
},
|
||||
{
|
||||
"name": "dadaptation.dadapt_sgd.DAdaptSGD",
|
||||
"alias": "DAdaptSGD",
|
||||
"instance": dadaptation.DAdaptSGD,
|
||||
},
|
||||
{
|
||||
"name": "prodigyopt.prodigy.Prodigy",
|
||||
"alias": "Prodigy",
|
||||
"instance": prodigyopt.Prodigy,
|
||||
},
|
||||
{
|
||||
"name": "transformers.optimization.Adafactor",
|
||||
"alias": "Adafactor",
|
||||
"instance": transformers.optimization.Adafactor,
|
||||
},
|
||||
{
|
||||
"name": "schedulefree.adamw_schedulefree.AdamWScheduleFree",
|
||||
"alias": "AdamWScheduleFree",
|
||||
"instance": sf.AdamWScheduleFree,
|
||||
},
|
||||
{
|
||||
"name": "schedulefree.sgd_schedulefree.SGDScheduleFree",
|
||||
"alias": "SGDScheduleFree",
|
||||
"instance": sf.SGDScheduleFree,
|
||||
},
|
||||
]
|
||||
|
||||
for opt in optimizers:
|
||||
with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]):
|
||||
parser = setup_parser()
|
||||
args = parser.parse_args()
|
||||
params_t = torch.tensor([1.5, 1.5])
|
||||
|
||||
param = Parameter(params_t)
|
||||
optimizer_name, _, optimizer = get_optimizer(args, [param])
|
||||
assert optimizer_name == opt.get("name")
|
||||
|
||||
instance = opt.get("instance")
|
||||
assert instance is not None
|
||||
assert isinstance(optimizer, instance)
|
||||
Reference in New Issue
Block a user