mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Add pytest testing
This commit is contained in:
54
.github/workflows/tests.yml
vendored
Normal file
54
.github/workflows/tests.yml
vendored
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
|
||||||
|
name: Python package
|
||||||
|
|
||||||
|
on: [push]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: ["3.10", "3.11"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.x'
|
||||||
|
- name: Install dependencies
|
||||||
|
run: python -m pip install --upgrade pip setuptools wheel
|
||||||
|
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.x'
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.x'
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
- name: Test with pytest
|
||||||
|
run: |
|
||||||
|
pip install pytest pytest-cov
|
||||||
|
pytest --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html
|
||||||
|
|
||||||
|
- name: Upload pytest test results
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: pytest-results-${{ matrix.python-version }}
|
||||||
|
path: junit/test-results-${{ matrix.python-version }}.xml
|
||||||
|
# Use always() to always run this step to publish test results when there are test failures
|
||||||
|
if: ${{ always() }}
|
||||||
@@ -21,7 +21,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union
|
||||||
)
|
)
|
||||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
||||||
import glob
|
import glob
|
||||||
@@ -4598,7 +4598,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
|
|||||||
accelerator.load_state(dirname)
|
accelerator.load_state(dirname)
|
||||||
|
|
||||||
|
|
||||||
def get_optimizer(args, trainable_params):
|
def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
|
||||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||||
|
|
||||||
optimizer_type = args.optimizer_type
|
optimizer_type = args.optimizer_type
|
||||||
|
|||||||
7
pytest.ini
Normal file
7
pytest.ini
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
[pytest]
|
||||||
|
minversion = 6.0
|
||||||
|
testpaths =
|
||||||
|
tests
|
||||||
|
filterwarnings =
|
||||||
|
ignore::DeprecationWarning
|
||||||
|
ignore::UserWarning
|
||||||
153
tests/test_optimizer.py
Normal file
153
tests/test_optimizer.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
from library.train_util import get_optimizer
|
||||||
|
from train_network import setup_parser
|
||||||
|
import torch
|
||||||
|
from torch.nn import Parameter
|
||||||
|
|
||||||
|
# Optimizer libraries
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
from lion_pytorch import lion_pytorch
|
||||||
|
import schedulefree
|
||||||
|
|
||||||
|
import dadaptation
|
||||||
|
import dadaptation.experimental as dadapt_experimental
|
||||||
|
|
||||||
|
import prodigyopt
|
||||||
|
import schedulefree as sf
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_get_optimizer():
|
||||||
|
with patch("sys.argv", [""]):
|
||||||
|
parser = setup_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
params_t = torch.tensor([1.5, 1.5])
|
||||||
|
|
||||||
|
param = Parameter(params_t)
|
||||||
|
optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
|
||||||
|
assert optimizer_name == "torch.optim.adamw.AdamW"
|
||||||
|
assert optimizer_args == ""
|
||||||
|
assert isinstance(optimizer, torch.optim.AdamW)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_schedulefree_optimizer():
|
||||||
|
with patch("sys.argv", ["", "--optimizer_type", "AdamWScheduleFree"]):
|
||||||
|
parser = setup_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
params_t = torch.tensor([1.5, 1.5])
|
||||||
|
|
||||||
|
param = Parameter(params_t)
|
||||||
|
optimizer_name, optimizer_args, optimizer = get_optimizer(args, [param])
|
||||||
|
assert optimizer_name == "schedulefree.adamw_schedulefree.AdamWScheduleFree"
|
||||||
|
assert optimizer_args == ""
|
||||||
|
assert isinstance(optimizer, schedulefree.adamw_schedulefree.AdamWScheduleFree)
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_supported_optimizers():
|
||||||
|
optimizers = [
|
||||||
|
{
|
||||||
|
"name": "bitsandbytes.optim.adamw.AdamW8bit",
|
||||||
|
"alias": "AdamW8bit",
|
||||||
|
"instance": bnb.optim.AdamW8bit,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "lion_pytorch.lion_pytorch.Lion",
|
||||||
|
"alias": "Lion",
|
||||||
|
"instance": lion_pytorch.Lion,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "torch.optim.adamw.AdamW",
|
||||||
|
"alias": "AdamW",
|
||||||
|
"instance": torch.optim.AdamW,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "bitsandbytes.optim.lion.Lion8bit",
|
||||||
|
"alias": "Lion8bit",
|
||||||
|
"instance": bnb.optim.Lion8bit,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "bitsandbytes.optim.adamw.PagedAdamW8bit",
|
||||||
|
"alias": "PagedAdamW8bit",
|
||||||
|
"instance": bnb.optim.PagedAdamW8bit,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "bitsandbytes.optim.lion.PagedLion8bit",
|
||||||
|
"alias": "PagedLion8bit",
|
||||||
|
"instance": bnb.optim.PagedLion8bit,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "bitsandbytes.optim.adamw.PagedAdamW",
|
||||||
|
"alias": "PagedAdamW",
|
||||||
|
"instance": bnb.optim.PagedAdamW,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "bitsandbytes.optim.adamw.PagedAdamW32bit",
|
||||||
|
"alias": "PagedAdamW32bit",
|
||||||
|
"instance": bnb.optim.PagedAdamW32bit,
|
||||||
|
},
|
||||||
|
{"name": "torch.optim.sgd.SGD", "alias": "SGD", "instance": torch.optim.SGD},
|
||||||
|
{
|
||||||
|
"name": "dadaptation.experimental.dadapt_adam_preprint.DAdaptAdamPreprint",
|
||||||
|
"alias": "DAdaptAdamPreprint",
|
||||||
|
"instance": dadapt_experimental.DAdaptAdamPreprint,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "dadaptation.dadapt_adagrad.DAdaptAdaGrad",
|
||||||
|
"alias": "DAdaptAdaGrad",
|
||||||
|
"instance": dadaptation.DAdaptAdaGrad,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "dadaptation.dadapt_adan.DAdaptAdan",
|
||||||
|
"alias": "DAdaptAdan",
|
||||||
|
"instance": dadaptation.DAdaptAdan,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "dadaptation.experimental.dadapt_adan_ip.DAdaptAdanIP",
|
||||||
|
"alias": "DAdaptAdanIP",
|
||||||
|
"instance": dadapt_experimental.DAdaptAdanIP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "dadaptation.dadapt_lion.DAdaptLion",
|
||||||
|
"alias": "DAdaptLion",
|
||||||
|
"instance": dadaptation.DAdaptLion,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "dadaptation.dadapt_sgd.DAdaptSGD",
|
||||||
|
"alias": "DAdaptSGD",
|
||||||
|
"instance": dadaptation.DAdaptSGD,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "prodigyopt.prodigy.Prodigy",
|
||||||
|
"alias": "Prodigy",
|
||||||
|
"instance": prodigyopt.Prodigy,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "transformers.optimization.Adafactor",
|
||||||
|
"alias": "Adafactor",
|
||||||
|
"instance": transformers.optimization.Adafactor,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "schedulefree.adamw_schedulefree.AdamWScheduleFree",
|
||||||
|
"alias": "AdamWScheduleFree",
|
||||||
|
"instance": sf.AdamWScheduleFree,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "schedulefree.sgd_schedulefree.SGDScheduleFree",
|
||||||
|
"alias": "SGDScheduleFree",
|
||||||
|
"instance": sf.SGDScheduleFree,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for opt in optimizers:
|
||||||
|
with patch("sys.argv", ["", "--optimizer_type", opt.get("alias")]):
|
||||||
|
parser = setup_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
params_t = torch.tensor([1.5, 1.5])
|
||||||
|
|
||||||
|
param = Parameter(params_t)
|
||||||
|
optimizer_name, _, optimizer = get_optimizer(args, [param])
|
||||||
|
assert optimizer_name == opt.get("name")
|
||||||
|
|
||||||
|
instance = opt.get("instance")
|
||||||
|
assert instance is not None
|
||||||
|
assert isinstance(optimizer, instance)
|
||||||
Reference in New Issue
Block a user