From 5ab00f9b49b5a3958bb0267fdb9236a96d503dbd Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:39:51 -0500 Subject: [PATCH 1/7] Update workflow tests with cleanup and documentation --- .github/workflows/tests.yml | 46 +++++++++++++++++++------------------ 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9ae67b0e..5a790d57 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,42 +1,44 @@ +name: Test with pytet -name: Python package - -on: [push] +on: + push: + branches: + - main + - dev + - sd3 + pull_request: + branches: + - main + - dev + - sd3 jobs: build: - runs-on: ${{ matrix.os }} strategy: matrix: os: [ubuntu-latest] - python-version: ["3.10"] + python-version: ["3.10"] # Python versions to test + pytorch-version: ["2.4.0"] # PyTorch versions to test steps: - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 + - uses: actions/setup-python@v5 with: - python-version: '3.x' + python-version: ${{ matrix.python-version }} + cache: 'pip' - - name: Install dependencies - run: python -m pip install --upgrade pip setuptools wheel - - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.x' - cache: 'pip' # caching pip dependencies + - name: Install and update pip, setuptools, wheel + run: | + # Setuptools, wheel for compiling some packages + python -m pip install --upgrade pip setuptools wheel - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install dadaptation==3.2 torch==2.4.0 torchvision==0.19.0 accelerate==0.33.0 + # Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch) + pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4 pip install -r requirements.txt - name: Test with pytest - run: | - pip install pytest - pytest + 
run: pytest # See pytest.ini for configuration From 63738ecb0758a02555392d2c283a83bba1c6f98e Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:48:30 -0500 Subject: [PATCH 2/7] Add tests documentation --- tests/README.md | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/README.md diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 00000000..19eeab0e --- /dev/null +++ b/tests/README.md @@ -0,0 +1,32 @@ +# Tests + +## Install + +``` +pip install pytest +``` + +## Usage + +``` +pytest +``` + +## Contribution + +Pytest is configured to run tests in this directory. It might be a good idea to add tests closer to the code, as well as doctests. + +Tests are functions starting with `test_` and files with the pattern `test_*.py`. + +``` +def test_x(): + assert 1 == 2, "Invalid test response" +``` + +## Resources + +- https://circleci.com/blog/testing-pytorch-model-with-pytest/ +- https://pytorch.org/docs/stable/testing.html +- https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests +- https://github.com/huggingface/pytorch-image-models/tree/main/tests +- https://github.com/pytorch/pytorch/tree/main/test From 2610e96e9e3d0605d5a16615efa26ae8935ed3aa Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:49:58 -0500 Subject: [PATCH 3/7] Pytest --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5a790d57..672a657b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: Test with pytet +name: Test with pytest on: push: From 3e5d89c76c287872e20c4a967d36b51384285be8 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 2 Dec 2024 13:51:57 -0500 Subject: [PATCH 4/7] Add more resources --- tests/README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/README.md b/tests/README.md index 19eeab0e..9836da8b 100644 --- 
a/tests/README.md +++ b/tests/README.md @@ -25,8 +25,17 @@ def test_x(): ## Resources +### pytest + +- https://docs.pytest.org/en/stable/index.html +- https://docs.pytest.org/en/stable/how-to/assert.html +- https://docs.pytest.org/en/stable/how-to/doctest.html + +### PyTorch testing + - https://circleci.com/blog/testing-pytorch-model-with-pytest/ - https://pytorch.org/docs/stable/testing.html - https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests - https://github.com/huggingface/pytorch-image-models/tree/main/tests - https://github.com/pytorch/pytorch/tree/main/test + From 6bee18db4fbf62ebd2a1da88a5851c48f2e06c54 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Dec 2024 15:12:27 +0900 Subject: [PATCH 5/7] fix: resolve model corruption issue with pos_embed when using --enable_scaled_pos_embed --- README.md | 2 ++ library/sd3_models.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index f0272519..6162359d 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,8 @@ The command to install PyTorch is as follows: ### Recent Updates +Dec 7, 2024: +- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`. 
Dec 3, 2024: diff --git a/library/sd3_models.py b/library/sd3_models.py index 2f3c82ee..e4a93186 100644 --- a/library/sd3_models.py +++ b/library/sd3_models.py @@ -870,8 +870,10 @@ class MMDiT(nn.Module): self.use_scaled_pos_embed = use_scaled_pos_embed if self.use_scaled_pos_embed: - # remove pos_embed to free up memory up to 0.4 GB - self.pos_embed = None + # # remove pos_embed to free up memory up to 0.4 GB -> this causes error because pos_embed is not saved + # self.pos_embed = None + # move pos_embed to CPU to free up memory up to 0.4 GB + self.pos_embed = self.pos_embed.cpu() # remove duplicates and sort latent sizes in ascending order latent_sizes = list(set(latent_sizes)) From abff4b0ec7bb37b338924e38392593f2bea2b8d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 7 Dec 2024 16:12:46 +0800 Subject: [PATCH 6/7] Unify controlnet parameters name and change scripts name. (#1821) * Update sd3_train.py * add freeze block lr * Update train_util.py * update * Revert "add freeze block lr" This reverts commit 8b1653548f8f219e5be2cde96f65a8813cf9ea1f. 
# Conflicts: # library/train_util.py # sd3_train.py * use same control net model path * use controlnet_model_name_or_path --- flux_train_control_net.py | 2 +- library/flux_train_utils.py | 2 +- sdxl_train_control_net.py | 8 ++++---- train_controlnet.py => train_control_net.py | 0 4 files changed, 6 insertions(+), 6 deletions(-) rename train_controlnet.py => train_control_net.py (100%) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 5548fd99..9d36a41d 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -265,7 +265,7 @@ def train(args): # load controlnet controlnet_dtype = torch.float32 if args.deepspeed else weight_dtype controlnet = flux_utils.load_controlnet( - args.controlnet, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors + args.controlnet_model_name_or_path, is_schnell, controlnet_dtype, accelerator.device, args.disable_mmap_load_safetensors ) controlnet.train() diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2e2b48..f7f06c5c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -564,7 +564,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") parser.add_argument( - "--controlnet", + "--controlnet_model_name_or_path", type=str, default=None, help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py index 01387409..ffbf03ca 100644 --- a/sdxl_train_control_net.py +++ b/sdxl_train_control_net.py @@ -184,12 +184,12 @@ def train(args): # make control net logger.info("make ControlNet") - if args.controlnet_model_path: + if args.controlnet_model_name_or_path: with init_empty_weights(): control_net = SdxlControlNet() - logger.info(f"load ControlNet from {args.controlnet_model_path}") - filename = 
args.controlnet_model_path + logger.info(f"load ControlNet from {args.controlnet_model_name_or_path}") + filename = args.controlnet_model_name_or_path if os.path.splitext(filename)[1] == ".safetensors": state_dict = load_file(filename) else: @@ -675,7 +675,7 @@ def setup_parser() -> argparse.ArgumentParser: sdxl_train_util.add_sdxl_training_arguments(parser) parser.add_argument( - "--controlnet_model_path", + "--controlnet_model_name_or_path", type=str, default=None, help="controlnet model name or path / controlnetのモデル名またはパス", diff --git a/train_controlnet.py b/train_control_net.py similarity index 100% rename from train_controlnet.py rename to train_control_net.py From e425996a5953f0479384e70b6490e751c2d00b1f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Dec 2024 17:28:19 +0900 Subject: [PATCH 7/7] feat: unify ControlNet model name option and deprecate old training script --- README.md | 7 +++++++ train_controlnet.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 train_controlnet.py diff --git a/README.md b/README.md index 6162359d..67836ddf 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,13 @@ The command to install PyTorch is as follows: ### Recent Updates Dec 7, 2024: + +- The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds! + + - Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`. 
Dec 3, 2024: diff --git a/train_controlnet.py b/train_controlnet.py new file mode 100644 index 00000000..365e35c8 --- /dev/null +++ b/train_controlnet.py @@ -0,0 +1,23 @@ +from library.utils import setup_logging + +setup_logging() +import logging + +logger = logging.getLogger(__name__) + + +from library import train_util +from train_control_net import setup_parser, train + +if __name__ == "__main__": + logger.warning( + "The module 'train_controlnet.py' is deprecated. Please use 'train_control_net.py' instead" + " / 'train_controlnet.py'は非推奨です。代わりに'train_control_net.py'を使用してください。" + ) + parser = setup_parser() + + args = parser.parse_args() + train_util.verify_command_line_training_args(args) + args = train_util.read_config_from_file(args, parser) + + train(args)