From 2258a1b753a321ed25b9ae1b7f2ceb1b24ae0736 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 31 Mar 2024 15:50:35 +0900 Subject: [PATCH] add save/load hook to remove U-Net/TEs from state --- README.md | 12 ++++++++++-- train_network.py | 25 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 71614b9c..87cdf0b7 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,10 @@ pip install --use-pep517 --upgrade -r requirements.txt Once the commands have completed successfully you should be ready to use the new version. +### Upgrade PyTorch + +If you want to upgrade PyTorch, you can upgrade it with `pip install` command in [Windows Installation](#windows-installation) section. `xformers` is also required to be upgraded when PyTorch is upgraded. + ## Credits The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work! @@ -137,12 +141,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. - - Especially `imagesize` is newly added, so if you cannot update immediately, please install with `pip install imagesize==1.4.1`. + - Especially `imagesize` is newly added, so if you cannot update the libraries immediately, please install with `pip install imagesize==1.4.1` separately. - `bitsandbytes==0.43.0`, `prodigyopt==1.0`, `lion-pytorch==0.0.6` are included in the requirements.txt. + - Also, the PyTorch version is updated to 2.1.2 (PyTorch does not need to be updated immediately). In the upgrade procedure, PyTorch is not updated, so please manually install or update torch, torchvision, xformers if necessary (see [Upgrade PyTorch](#upgrade-pytorch)). - Colab seems to stop with log output. Try specifying `--console_log_simple` option in the training script to disable rich logging. - The `.toml` file for the dataset config is now read in UTF-8 encoding. PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Thanks to Horizon1704! - Fixed a bug that the last subset settings are applied to all images when multiple subsets of regularization images are specified in the dataset settings. The settings for each subset are correctly applied to each image. PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) Thanks to feffy380! - `train_network.py` and `sdxl_train_network.py` are modified to record some dataset settings in the metadata of the trained model (`caption_prefix`, `caption_suffix`, `keep_tokens_separator`, `secondary_separator`, `enable_wildcard`). +- Fixed a bug that U-Net and Text Encoders are included in the state in `train_network.py` and `sdxl_train_network.py`. The saving and loading of the state are faster, the file size is smaller, and the memory usage when loading is reduced. - DeepSpeed is supported. PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) and [#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) Thanks to BootsofLagrangian! See PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) for details. - The masked loss is supported in each training script. PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) See [Masked loss](#masked-loss) for details. - Some features are added to the dataset subset settings. @@ -171,12 +177,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - The [English version of the dataset settings documentation](./docs/config_README-en.md) is added. PR [#1175](https://github.com/kohya-ss/sd-scripts/pull/1175) Thanks to darkstorm2150! - 依存ライブラリが更新されました。[アップグレード](./README-ja.md#アップグレード) を参照しライブラリを更新してください。 - - 特に `imagesize` が新しく追加されていますので、すぐに更新ができない場合は `pip install imagesize==1.4.1` でインストールしてください。 + - 特に `imagesize` が新しく追加されていますので、すぐにライブラリの更新ができない場合は `pip install imagesize==1.4.1` で個別にインストールしてください。 - `bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` が requirements.txt に含まれるようになりました。 + - また PyTorch のバージョンを 2.1.2 に更新しました。PyTorch はすぐに更新する必要はありません。更新時は、アップグレードの手順では PyTorch が更新されませんので、torch、torchvision、xformers を手動でインストールしてください。 - Colab での動作時、ログ出力で停止してしまうようです。学習スクリプトに `--console_log_simple` オプションを指定し、rich のロギングを無効してお試しください。 - データセット設定の `.toml` ファイルが UTF-8 encoding で読み込まれるようになりました。PR [#1167](https://github.com/kohya-ss/sd-scripts/pull/1167) Horizon1704 氏に感謝します。 - データセット設定で、正則化画像のサブセットを複数指定した時、最後のサブセットの各種設定がすべてのサブセットの画像に適用される不具合が修正されました。それぞれのサブセットの設定が、それぞれの画像に正しく適用されます。PR [#1205](https://github.com/kohya-ss/sd-scripts/pull/1205) feffy380 氏に感謝します。 - `train_network.py` および `sdxl_train_network.py` で、学習したモデルのメタデータに一部のデータセット設定が記録されるよう修正しました(`caption_prefix`、`caption_suffix`、`keep_tokens_separator`、`secondary_separator`、`enable_wildcard`)。 +- `train_network.py` および `sdxl_train_network.py` で、state に U-Net および Text Encoder が含まれる不具合を修正しました。state の保存、読み込みが高速化され、ファイルサイズも小さくなり、また読み込み時のメモリ使用量も削減されます。 - DeepSpeed がサポートされました。PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) 、[#1139](https://github.com/kohya-ss/sd-scripts/pull/1139) BootsofLagrangian 氏に感謝します。詳細は PR [#1101](https://github.com/kohya-ss/sd-scripts/pull/1101) をご覧ください。 - 各学習スクリプトでマスクロスをサポートしました。PR [#1207](https://github.com/kohya-ss/sd-scripts/pull/1207) 詳細は [Masked loss](#masked-loss) をご覧ください。 - データセットのサブセット設定にいくつかの機能を追加しました。 diff --git a/train_network.py b/train_network.py index ed569aea..8fe98f12 100644 --- a/train_network.py +++ b/train_network.py @@ -471,6 +471,31 @@ class NetworkTrainer: if args.full_fp16: train_util.patch_accelerator_for_fp16_training(accelerator) + # before resuming make hook for saving/loading to save/load the network weights only + def save_model_hook(models, weights, output_dir): + # pop weights of other models than network to save only network weights + if accelerator.is_main_process: + remove_indices = [] + for i,model in enumerate(models): + if not isinstance(model, type(accelerator.unwrap_model(network))): + remove_indices.append(i) + for i in reversed(remove_indices): + weights.pop(i) + # print(f"save model hook: {len(weights)} weights will be saved") + + def load_model_hook(models, input_dir): + # remove models except network + remove_indices = [] + for i, model in enumerate(models): + if not isinstance(model, type(accelerator.unwrap_model(network))): + remove_indices.append(i) + for i in reversed(remove_indices): + models.pop(i) + # print(f"load model hook: {len(models)} models will be loaded") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + # resumeする train_util.resume_from_local_or_hf_if_specified(accelerator, args)