mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'kohya-ss:main' into support-multi-gpu
This commit is contained in:
21
.github/workflows/typos.yml
vendored
Normal file
21
.github/workflows/typos.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
---
|
||||
# yamllint disable rule:line-length
|
||||
name: Typos
|
||||
|
||||
on: # yamllint disable-line rule:truthy
|
||||
push:
|
||||
pull_request:
|
||||
types:
|
||||
- opened
|
||||
- synchronize
|
||||
- reopened
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.13.10
|
||||
@@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。(bf1
|
||||
cd sd-scripts
|
||||
git pull
|
||||
.\venv\Scripts\activate
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install --use-pep517 --upgrade -r requirements.txt
|
||||
```
|
||||
|
||||
コマンドが成功すれば新しいバージョンが使用できます。
|
||||
|
||||
139
README.md
139
README.md
@@ -1,51 +1,7 @@
|
||||
This repository contains training, generation and utility scripts for Stable Diffusion.
|
||||
|
||||
## Updates
|
||||
|
||||
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
|
||||
|
||||
Note: The LoRA models for SD 2.x is not supported too in Web UI.
|
||||
|
||||
- 6 Feb. 2023, 2023/2/6:
|
||||
- ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``.
|
||||
- ``--bucket_reso_steps`` takes the steps for buckets in aspect ratio bucketing. Default is 64, same as before.
|
||||
- Any value greater than or equal to 1 can be specified; 64 is highly recommended and a value divisible by 8 is recommended.
|
||||
- If less than 64 is specified, padding will occur within U-Net. The result is unknown.
|
||||
- If you specify a value that is not divisible by 8, it will be truncated to divisible by 8 inside VAE, because the size of the latent is 1/8 of the image size.
|
||||
- If ``--bucket_no_upscale`` option is specified, images smaller than the bucket size will be processed without upscaling.
|
||||
- Internally, a bucket smaller than the image size is created (for example, if the image is 300x300 and ``bucket_reso_steps=64``, the bucket is 256x256). The image will be trimmed.
|
||||
- Implementation of [#130](https://github.com/kohya-ss/sd-scripts/issues/130).
|
||||
- Images with an area larger than the maximum size specified by ``--resolution`` are downsampled to the max bucket size.
|
||||
- Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
|
||||
- ``--random_crop`` now also works with buckets enabled.
|
||||
- Instead of always cropping the center of the image, the image is shifted left, right, up, and down to be used as the training data. This is expected to train to the edges of the image.
|
||||
- Implementation of discussion [#34](https://github.com/kohya-ss/sd-scripts/discussions/34).
|
||||
|
||||
- ``--bucket_reso_steps``および``--bucket_no_upscale``オプションを、学習スクリプトおよび``prepare_buckets_latents.py``に追加しました。
|
||||
- ``--bucket_reso_steps``オプションでは、bucketの解像度の単位を指定できます。デフォルトは64で、今までと同じ動作です。
|
||||
- 1以上の任意の値を指定できます。基本的には64を推奨します。64以外の値では、8で割り切れる値を推奨します。
|
||||
- 64未満を指定するとU-Netの内部でpaddingが発生します。どのような結果になるかは未知数です。
|
||||
- 8で割り切れない値を指定すると余りはVAE内部で切り捨てられます。
|
||||
- ``--bucket_no_upscale``オプションを指定すると、bucketサイズよりも小さい画像は拡大せずそのまま処理します。
|
||||
- 内部的には画像サイズ以下のサイズのbucketを作成します(たとえば画像が300x300で``bucket_reso_steps=64``の場合、256x256のbucket)。余りは都度trimmingされます。
|
||||
- [#130](https://github.com/kohya-ss/sd-scripts/issues/130) を実装したものです。
|
||||
- ``--resolution``で指定した最大サイズよりも面積が大きい画像は、最大サイズと同じ面積になるようアスペクト比を維持したまま縮小され、そのサイズを元にbucketが作られます。
|
||||
- これらのオプションによりbucketが細分化され、ひとつのバッチ内に同一画像が重複して存在することが増えたため、バッチサイズを``そのbucketの画像種類数``までに制限する機能を追加しました。
|
||||
- たとえば繰り返し回数10で、あるbucketに1枚しか画像がなく、バッチサイズが10以上のとき、今まではepoch内で、同一画像を10枚含むバッチが1回だけ使用されていました。
|
||||
- 機能追加後はepoch内にサイズ1のバッチが10回、使用されます。
|
||||
- ``--random_crop``がbucketを有効にした場合にも機能するようになりました。
|
||||
- 常に画像の中央を切り取るのではなく、左右、上下にずらして教師データにします。これにより画像端まで学習されることが期待されます。
|
||||
- discussionの[#34](https://github.com/kohya-ss/sd-scripts/discussions/34)を実装したものです。
|
||||
|
||||
|
||||
Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。
|
||||
|
||||
注:SD2.x用のLoRAモデルはサポートされないようです。
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
##
|
||||
[__Change History__](#change-history) is moved to the bottom of the page.
|
||||
更新履歴は[ページ末尾](#change-history)に移しました。
|
||||
|
||||
[日本語版README](./README-ja.md)
|
||||
|
||||
@@ -54,10 +10,13 @@ For easier use (GUI and PowerShell scripts etc...), please visit [the repository
|
||||
This repository contains the scripts for:
|
||||
|
||||
* DreamBooth training, including U-Net and Text Encoder
|
||||
* fine-tuning (native training), including U-Net and Text Encoder
|
||||
* Fine-tuning (native training), including U-Net and Text Encoder
|
||||
* LoRA training
|
||||
* image generation
|
||||
* model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
|
||||
* Texutl Inversion training
|
||||
* Image generation
|
||||
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
|
||||
|
||||
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ (SD 1.x based only) Thank you for great work!!!
|
||||
|
||||
## About requirements.txt
|
||||
|
||||
@@ -144,7 +103,7 @@ When a new release comes out you can upgrade your repo with the following comman
|
||||
cd sd-scripts
|
||||
git pull
|
||||
.\venv\Scripts\activate
|
||||
pip install --upgrade -r requirements.txt
|
||||
pip install --use-pep517 --upgrade -r requirements.txt
|
||||
```
|
||||
|
||||
Once the commands have completed successfully you should be ready to use the new version.
|
||||
@@ -162,3 +121,83 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
|
||||
|
||||
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
|
||||
|
||||
## Change History
|
||||
|
||||
- 11 Feb. 2023, 2023/2/11:
|
||||
- ``lora_interrogator.py`` is added in ``networks`` folder. See ``python networks\lora_interrogator.py -h`` for usage.
|
||||
- For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate.
|
||||
- Batch size can be large (like 64 or 128).
|
||||
- ``train_textual_inversion.py`` now supports multiple init words.
|
||||
- Following feature is reverted to be the same as before. Sorry for confusion:
|
||||
> Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
|
||||
|
||||
- ``lora_interrogator.py`` を ``network``フォルダに追加しました。使用法は ``python networks\lora_interrogator.py -h`` でご確認ください。
|
||||
- このスクリプトは、起動promptがわからないLoRAについて、LoRA適用前後のText Encoderの出力を比較することで、どのtokenの出力が変化しているかを調べます。運が良ければ起動用の単語が分かります。キャプション付きで学習されたLoRAは影響が広範囲に及ぶため、調査は難しいようです。
|
||||
- バッチサイズはわりと大きくできます(64や128など)。
|
||||
- ``train_textual_inversion.py`` で複数のinit_word指定が可能になりました。
|
||||
- 次の機能を削除し元に戻しました。混乱を招き申し訳ありません。
|
||||
> これらのオプションによりbucketが細分化され、ひとつのバッチ内に同一画像が重複して存在することが増えたため、バッチサイズを``そのbucketの画像種類数``までに制限する機能を追加しました。
|
||||
|
||||
- 10 Feb. 2023, 2023/2/10:
|
||||
- Updated ``requirements.txt`` to prevent upgrading with pip taking a long time or failure to upgrade.
|
||||
- ``resize_lora.py`` keeps the metadata of the model. ``dimension is resized from ...`` is added to the top of ``ss_training_comment``.
|
||||
- ``merge_lora.py`` supports models with different ``alpha``s. If there is a problem, old version is ``merge_lora_old.py``.
|
||||
- ``svd_merge_lora.py`` is added. This script merges LoRA models with any rank (dim) and alpha, and approximate a new LoRA with svd for a specified rank (dim).
|
||||
- Note: merging scripts erase the metadata currently.
|
||||
- ``resize_images_to_resolution.py`` supports multibyte characters in filenames.
|
||||
- pipでの更新が長時間掛かったり、更新に失敗したりするのを防ぐため、``requirements.txt``を更新しました。
|
||||
- ``resize_lora.py``がメタデータを保持するようになりました。 ``dimension is resized from ...`` という文字列が ``ss_training_comment`` の先頭に追加されます。
|
||||
- ``merge_lora.py``がalphaが異なるモデルをサポートしました。 何か問題がありましたら旧バージョン ``merge_lora_old.py`` をお使いください。
|
||||
- ``svd_merge_lora.py`` を追加しました。 複数の任意のdim (rank)、alphaのLoRAモデルをマージし、svdで任意dim(rank)のLoRAで近似します。
|
||||
- 注:マージ系のスクリプトは現時点ではメタデータを消去しますのでご注意ください。
|
||||
- ``resize_images_to_resolution.py``が日本語ファイル名をサポートしました。
|
||||
|
||||
- 9 Feb. 2023, 2023/2/9:
|
||||
- Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
|
||||
- ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).
|
||||
- ``--caption_dropout_every_n_epochs`` option specifies how many epochs to drop captions. If ``3`` is specified, in epoch 3, 6, 9 ..., images are trained with all captions empty. Default is None (no dropout).
|
||||
- ``--caption_tag_dropout_rate`` option specified the dropout rate for tags (comma separated tokens) (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the tag is removed from the caption. If ``--keep_tokens`` option is set, these tokens (tags) are not dropped. Default is 0 (no droupout).
|
||||
- The bulk image downsampling script is added. Documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) (in Jpanaese). Thanks to bmaltais!
|
||||
- Typo check is added. Thanks to shirayu!
|
||||
- キャプションのドロップアウトを``train_db.py``、``fine_tune.py``、``train_network.py``の各スクリプトに追加しました。forestsource氏に感謝します。
|
||||
- ``--caption_dropout_rate``オプションでキャプションのドロップアウト率を指定します(0~1.0、 0.1を指定すると10%の確率でドロップアウト)。ドロップアウトされた場合、画像は空のキャプションで学習されます。デフォルトは 0 (ドロップアウトなし)です。
|
||||
- ``--caption_dropout_every_n_epochs`` オプションで何エポックごとにキャプションを完全にドロップアウトするか指定します。たとえば``3``を指定すると、エポック3、6、9……で、すべての画像がキャプションなしで学習されます。デフォルトは None (ドロップアウトなし)です。
|
||||
- ``--caption_tag_dropout_rate`` オプションで各タグ(カンマ区切りの各部分)のドロップアウト率を指定します(0~1.0、 0.1を指定すると10%の確率でドロップアウト)。ドロップアウトが起きるとそのタグはそのときだけキャプションから取り除かれて学習されます。``--keep_tokens`` オプションを指定していると、シャッフルされない部分のタグはドロップアウトされません。デフォルトは 0 (ドロップアウトなし)です。
|
||||
- 画像の一括縮小スクリプトを追加しました。ドキュメントは [こちら](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) です。bmaltais氏に感謝します。
|
||||
- 誤字チェッカが追加されました。shirayu氏に感謝します。
|
||||
|
||||
- 6 Feb. 2023, 2023/2/6:
|
||||
- ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``.
|
||||
- ``--bucket_reso_steps`` takes the steps for buckets in aspect ratio bucketing. Default is 64, same as before.
|
||||
- Any value greater than or equal to 1 can be specified; 64 is highly recommended and a value divisible by 8 is recommended.
|
||||
- If less than 64 is specified, padding will occur within U-Net. The result is unknown.
|
||||
- If you specify a value that is not divisible by 8, it will be truncated to divisible by 8 inside VAE, because the size of the latent is 1/8 of the image size.
|
||||
- If ``--bucket_no_upscale`` option is specified, images smaller than the bucket size will be processed without upscaling.
|
||||
- Internally, a bucket smaller than the image size is created (for example, if the image is 300x300 and ``bucket_reso_steps=64``, the bucket is 256x256). The image will be trimmed.
|
||||
- Implementation of [#130](https://github.com/kohya-ss/sd-scripts/issues/130).
|
||||
- Images with an area larger than the maximum size specified by ``--resolution`` are downsampled to the max bucket size.
|
||||
- Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
|
||||
- ``--random_crop`` now also works with buckets enabled.
|
||||
- Instead of always cropping the center of the image, the image is shifted left, right, up, and down to be used as the training data. This is expected to train to the edges of the image.
|
||||
- Implementation of discussion [#34](https://github.com/kohya-ss/sd-scripts/discussions/34).
|
||||
|
||||
- ``--bucket_reso_steps``および``--bucket_no_upscale``オプションを、学習スクリプトおよび``prepare_buckets_latents.py``に追加しました。
|
||||
- ``--bucket_reso_steps``オプションでは、bucketの解像度の単位を指定できます。デフォルトは64で、今までと同じ動作です。
|
||||
- 1以上の任意の値を指定できます。基本的には64を推奨します。64以外の値では、8で割り切れる値を推奨します。
|
||||
- 64未満を指定するとU-Netの内部でpaddingが発生します。どのような結果になるかは未知数です。
|
||||
- 8で割り切れない値を指定すると余りはVAE内部で切り捨てられます。
|
||||
- ``--bucket_no_upscale``オプションを指定すると、bucketサイズよりも小さい画像は拡大せずそのまま処理します。
|
||||
- 内部的には画像サイズ以下のサイズのbucketを作成します(たとえば画像が300x300で``bucket_reso_steps=64``の場合、256x256のbucket)。余りは都度trimmingされます。
|
||||
- [#130](https://github.com/kohya-ss/sd-scripts/issues/130) を実装したものです。
|
||||
- ``--resolution``で指定した最大サイズよりも面積が大きい画像は、最大サイズと同じ面積になるようアスペクト比を維持したまま縮小され、そのサイズを元にbucketが作られます。
|
||||
- これらのオプションによりbucketが細分化され、ひとつのバッチ内に同一画像が重複して存在することが増えたため、バッチサイズを``そのbucketの画像種類数``までに制限する機能を追加しました。
|
||||
- たとえば繰り返し回数10で、あるbucketに1枚しか画像がなく、バッチサイズが10以上のとき、今まではepoch内で、同一画像を10枚含むバッチが1回だけ使用されていました。
|
||||
- 機能追加後はepoch内にサイズ1のバッチが10回、使用されます。
|
||||
- ``--random_crop``がbucketを有効にした場合にも機能するようになりました。
|
||||
- 常に画像の中央を切り取るのではなく、左右、上下にずらして教師データにします。これにより画像端まで学習されることが期待されます。
|
||||
- discussionの[#34](https://github.com/kohya-ss/sd-scripts/discussions/34)を実装したものです。
|
||||
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
15
_typos.toml
Normal file
15
_typos.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
# Files for typos
|
||||
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
|
||||
|
||||
[default.extend-identifiers]
|
||||
|
||||
[default.extend-words]
|
||||
NIN="NIN"
|
||||
parms="parms"
|
||||
nin="nin"
|
||||
extention="extention" # Intentionally left
|
||||
nd="nd"
|
||||
|
||||
|
||||
[files]
|
||||
extend-exclude = ["_typos.toml"]
|
||||
@@ -36,6 +36,10 @@ def train(args):
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||
args.dataset_repeats, args.debug_dataset)
|
||||
|
||||
# 学習データのdropout率を設定する
|
||||
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
||||
|
||||
train_dataset.make_buckets()
|
||||
|
||||
if args.debug_dataset:
|
||||
@@ -226,6 +230,8 @@ def train(args):
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset.set_current_epoch(epoch + 1)
|
||||
|
||||
for m in training_models:
|
||||
m.train()
|
||||
|
||||
@@ -332,7 +338,7 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, False, True)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
|
||||
|
||||
@@ -113,7 +113,7 @@ class BucketManager():
|
||||
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
|
||||
self.predefined_resos = resos.copy()
|
||||
self.predefined_resos_set = set(resos)
|
||||
self.predifined_aspect_ratios = np.array([w / h for w, h in resos])
|
||||
self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
|
||||
|
||||
def add_if_new_reso(self, reso):
|
||||
if reso not in self.reso_to_id:
|
||||
@@ -135,7 +135,7 @@ class BucketManager():
|
||||
if reso in self.predefined_resos_set:
|
||||
pass
|
||||
else:
|
||||
ar_errors = self.predifined_aspect_ratios - aspect_ratio
|
||||
ar_errors = self.predefined_aspect_ratios - aspect_ratio
|
||||
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
|
||||
reso = self.predefined_resos[predefined_bucket_id]
|
||||
|
||||
@@ -223,6 +223,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
|
||||
self.dropout_rate: float = 0
|
||||
self.dropout_every_n_epochs: int = None
|
||||
self.tag_dropout_rate: float = 0
|
||||
|
||||
# augmentation
|
||||
flip_p = 0.5 if flip_aug else 0.0
|
||||
if color_aug:
|
||||
@@ -247,6 +252,15 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.replacements = {}
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
self.current_epoch = epoch
|
||||
|
||||
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
|
||||
# コンストラクタで渡さないのはTextual Inversionで意識したくないから(ということにしておく)
|
||||
self.dropout_rate = dropout_rate
|
||||
self.dropout_every_n_epochs = dropout_every_n_epochs
|
||||
self.tag_dropout_rate = tag_dropout_rate
|
||||
|
||||
def set_tag_frequency(self, dir_name, captions):
|
||||
frequency_for_dir = self.tag_frequency.get(dir_name, {})
|
||||
self.tag_frequency[dir_name] = frequency_for_dir
|
||||
@@ -264,18 +278,43 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.replacements[str_from] = str_to
|
||||
|
||||
def process_caption(self, caption):
|
||||
if self.shuffle_caption:
|
||||
tokens = caption.strip().split(",")
|
||||
# dropoutの決定:tag dropがこのメソッド内にあるのでここで行うのが良い
|
||||
is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
|
||||
is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
|
||||
|
||||
if is_drop_out:
|
||||
caption = ""
|
||||
else:
|
||||
if self.shuffle_caption or self.tag_dropout_rate > 0:
|
||||
def dropout_tags(tokens):
|
||||
if self.tag_dropout_rate <= 0:
|
||||
return tokens
|
||||
l = []
|
||||
for token in tokens:
|
||||
if random.random() >= self.tag_dropout_rate:
|
||||
l.append(token)
|
||||
return l
|
||||
|
||||
tokens = [t.strip() for t in caption.strip().split(",")]
|
||||
if self.shuffle_keep_tokens is None:
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
else:
|
||||
if len(tokens) > self.shuffle_keep_tokens:
|
||||
keep_tokens = tokens[:self.shuffle_keep_tokens]
|
||||
tokens = tokens[self.shuffle_keep_tokens:]
|
||||
random.shuffle(tokens)
|
||||
tokens = keep_tokens + tokens
|
||||
caption = ",".join(tokens).strip()
|
||||
|
||||
if self.shuffle_caption:
|
||||
random.shuffle(tokens)
|
||||
|
||||
tokens = dropout_tags(tokens)
|
||||
|
||||
tokens = keep_tokens + tokens
|
||||
caption = ", ".join(tokens)
|
||||
|
||||
# textual inversion対応
|
||||
for str_from, str_to in self.replacements.items():
|
||||
if str_from == "":
|
||||
# replace all
|
||||
@@ -393,17 +432,25 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
|
||||
self.buckets_indices: List(BucketBatchIndex) = []
|
||||
for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
|
||||
# bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# そのためバッチサイズを画像種類までに制限する
|
||||
# ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# TODO 正則化画像をepochまたがりで利用する仕組み
|
||||
num_of_image_types = len(set(bucket))
|
||||
bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
batch_count = int(math.ceil(len(bucket) / self.batch_size))
|
||||
for batch_index in range(batch_count):
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
|
||||
|
||||
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
|
||||
# 学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
|
||||
#
|
||||
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
|
||||
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
|
||||
# # そのためバッチサイズを画像種類までに制限する
|
||||
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない?
|
||||
# # TO DO 正則化画像をepochまたがりで利用する仕組み
|
||||
# num_of_image_types = len(set(bucket))
|
||||
# bucket_batch_size = min(self.batch_size, num_of_image_types)
|
||||
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
|
||||
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
|
||||
# for batch_index in range(batch_count):
|
||||
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
|
||||
# ↑ここまで
|
||||
|
||||
self.shuffle_buckets()
|
||||
self._length = len(self.buckets_indices)
|
||||
@@ -809,6 +856,7 @@ class FineTuningDataset(BaseDataset):
|
||||
self.num_train_images = len(metadata) * dataset_repeats
|
||||
self.num_reg_images = 0
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
|
||||
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
|
||||
|
||||
@@ -907,6 +955,8 @@ class FineTuningDataset(BaseDataset):
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("Escape for exit. / Escキーで中断、終了します")
|
||||
|
||||
train_dataset.set_current_epoch(1)
|
||||
k = 0
|
||||
for i, example in enumerate(train_dataset):
|
||||
if example['latents'] is not None:
|
||||
@@ -1377,7 +1427,7 @@ def verify_training_args(args: argparse.Namespace):
|
||||
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
|
||||
|
||||
|
||||
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool):
|
||||
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
|
||||
# dataset common
|
||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||
parser.add_argument("--shuffle_caption", action="store_true",
|
||||
@@ -1408,6 +1458,16 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
|
||||
parser.add_argument("--bucket_no_upscale", action="store_true",
|
||||
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
|
||||
|
||||
if support_caption_dropout:
|
||||
# Textual Inversion はcaptionのdropoutをsupportしない
|
||||
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
|
||||
parser.add_argument("--caption_dropout_rate", type=float, default=0,
|
||||
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
|
||||
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
|
||||
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
|
||||
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
|
||||
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
|
||||
|
||||
if support_dreambooth:
|
||||
# DreamBooth dataset
|
||||
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import List
|
||||
import torch
|
||||
|
||||
from library import train_util
|
||||
@@ -98,7 +99,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
self.alpha = alpha
|
||||
|
||||
# create module instances
|
||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
|
||||
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
|
||||
loras = []
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
|
||||
122
networks/lora_interrogator.py
Normal file
122
networks/lora_interrogator.py
Normal file
@@ -0,0 +1,122 @@
|
||||
|
||||
|
||||
from tqdm import tqdm
|
||||
from library import model_util
|
||||
import argparse
|
||||
from transformers import CLIPTokenizer
|
||||
import torch
|
||||
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
|
||||
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
|
||||
|
||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
|
||||
def interrogate(args):
|
||||
# いろいろ準備する
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
|
||||
print(f"loading LoRA: {args.model}")
|
||||
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
|
||||
|
||||
# text encoder向けの重みがあるかチェックする:本当はlora側でやるのがいい
|
||||
has_te_weight = False
|
||||
for key in network.weights_sd.keys():
|
||||
if 'lora_te' in key:
|
||||
has_te_weight = True
|
||||
break
|
||||
if not has_te_weight:
|
||||
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
|
||||
return
|
||||
del vae
|
||||
|
||||
print("loading tokenizer")
|
||||
if args.v2:
|
||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
|
||||
else:
|
||||
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
|
||||
|
||||
text_encoder.to(DEVICE)
|
||||
text_encoder.eval()
|
||||
unet.to(DEVICE)
|
||||
unet.eval() # U-Netは呼び出さないので不要だけど
|
||||
|
||||
# トークンをひとつひとつ当たっていく
|
||||
token_id_start = 0
|
||||
token_id_end = max(tokenizer.all_special_ids)
|
||||
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
|
||||
|
||||
def get_all_embeddings(text_encoder):
|
||||
embs = []
|
||||
with torch.no_grad():
|
||||
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
|
||||
batch = []
|
||||
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
|
||||
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
|
||||
# tokens = [tid] # こちらは結果がいまひとつ
|
||||
batch.append(tokens)
|
||||
|
||||
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
|
||||
# clip skip対応
|
||||
batch = torch.tensor(batch).to(DEVICE)
|
||||
if args.clip_skip is None:
|
||||
encoder_hidden_states = text_encoder(batch)[0]
|
||||
else:
|
||||
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
|
||||
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.to("cpu")
|
||||
|
||||
embs.extend(encoder_hidden_states)
|
||||
return torch.stack(embs)
|
||||
|
||||
print("get original text encoder embeddings.")
|
||||
orig_embs = get_all_embeddings(text_encoder)
|
||||
|
||||
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
|
||||
network.to(DEVICE)
|
||||
network.eval()
|
||||
|
||||
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません(以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません)")
|
||||
print("get text encoder embeddings with lora.")
|
||||
lora_embs = get_all_embeddings(text_encoder)
|
||||
|
||||
# 比べる:とりあえず単純に差分の絶対値で
|
||||
print("comparing...")
|
||||
diffs = {}
|
||||
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
|
||||
diff = torch.mean(torch.abs(orig_emb - lora_emb))
|
||||
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
|
||||
diff = float(diff.detach().to('cpu').numpy())
|
||||
diffs[token_id_start + i] = diff
|
||||
|
||||
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
|
||||
|
||||
# 結果を表示する
|
||||
print("top 100:")
|
||||
for i, (token, diff) in enumerate(diffs_sorted[:100]):
|
||||
# if diff < 1e-6:
|
||||
# break
|
||||
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
|
||||
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--model", type=str, default=None,
|
||||
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--batch_size", type=int, default=16,
|
||||
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
|
||||
parser.add_argument("--clip_skip", type=int, default=None,
|
||||
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
|
||||
|
||||
args = parser.parse_args()
|
||||
interrogate(args)
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
@@ -85,43 +85,76 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
# conv2d
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, merge_dtype):
|
||||
merged_sd = {}
|
||||
base_alphas = {} # alpha for merged model
|
||||
base_dims = {}
|
||||
|
||||
alpha = None
|
||||
dim = None
|
||||
merged_sd = {}
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
# get alpha and dim
|
||||
alphas = {} # alpha for current model
|
||||
dims = {} # dims for current model
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
lora_module_name = key[:key.rfind(".alpha")]
|
||||
alpha = float(lora_sd[key].detach().numpy())
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
elif "lora_down" in key:
|
||||
lora_module_name = key[:key.rfind(".lora_down")]
|
||||
dim = lora_sd[key].size()[0]
|
||||
dims[lora_module_name] = dim
|
||||
if lora_module_name not in base_dims:
|
||||
base_dims[lora_module_name] = dim
|
||||
|
||||
for lora_module_name in dims.keys():
|
||||
if lora_module_name not in alphas:
|
||||
alpha = dims[lora_module_name]
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
|
||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
|
||||
else:
|
||||
alpha = lora_sd[key].detach().numpy()
|
||||
merged_sd[key] = lora_sd[key]
|
||||
else:
|
||||
continue
|
||||
|
||||
lora_module_name = key[:key.rfind(".lora_")]
|
||||
|
||||
base_alpha = base_alphas[lora_module_name]
|
||||
alpha = alphas[lora_module_name]
|
||||
|
||||
scale = math.sqrt(alpha / base_alpha) * ratio
|
||||
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key].size() == lora_sd[key].size(
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||
else:
|
||||
if "lora_down" in key:
|
||||
dim = lora_sd[key].size()[0]
|
||||
merged_sd[key] = lora_sd[key] * ratio
|
||||
merged_sd[key] = lora_sd[key] * scale
|
||||
|
||||
print(f"dim (rank): {dim}, alpha: {alpha}")
|
||||
if alpha is None:
|
||||
alpha = dim
|
||||
# set alpha to sd
|
||||
for lora_module_name, alpha in base_alphas.items():
|
||||
key = lora_module_name + ".alpha"
|
||||
merged_sd[key] = torch.tensor(alpha)
|
||||
|
||||
return merged_sd, dim, alpha
|
||||
print("merged model")
|
||||
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
|
||||
|
||||
return merged_sd
|
||||
|
||||
|
||||
def merge(args):
|
||||
@@ -152,7 +185,7 @@ def merge(args):
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
179
networks/merge_lora_old.py
Normal file
179
networks/merge_lora_old.py
Normal file
@@ -0,0 +1,179 @@
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
sd = load_file(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
return sd
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
text_encoder.to(merge_dtype)
|
||||
unet.to(merge_dtype)
|
||||
|
||||
# create module map
|
||||
name_to_module = {}
|
||||
for i, root_module in enumerate([text_encoder, unet]):
|
||||
if i == 0:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
|
||||
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
|
||||
else:
|
||||
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
|
||||
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
|
||||
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
name_to_module[lora_name] = child_module
|
||||
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if "lora_down" in key:
|
||||
up_key = key.replace("lora_down", "lora_up")
|
||||
alpha_key = key[:key.index("lora_down")] + 'alpha'
|
||||
|
||||
# find original module for this lora
|
||||
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
|
||||
if module_name not in name_to_module:
|
||||
print(f"no module found for LoRA weight: {key}")
|
||||
continue
|
||||
module = name_to_module[module_name]
|
||||
# print(f"apply {key} to {module}")
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
up_weight = lora_sd[up_key]
|
||||
|
||||
dim = down_weight.size()[0]
|
||||
alpha = lora_sd.get(alpha_key, dim)
|
||||
scale = alpha / dim
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
# conv2d
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, merge_dtype):
|
||||
merged_sd = {}
|
||||
|
||||
alpha = None
|
||||
dim = None
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
print(f"merging...")
|
||||
for key in lora_sd.keys():
|
||||
if 'alpha' in key:
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
|
||||
else:
|
||||
alpha = lora_sd[key].detach().numpy()
|
||||
merged_sd[key] = lora_sd[key]
|
||||
else:
|
||||
if key in merged_sd:
|
||||
assert merged_sd[key].size() == lora_sd[key].size(
|
||||
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
|
||||
else:
|
||||
if "lora_down" in key:
|
||||
dim = lora_sd[key].size()[0]
|
||||
merged_sd[key] = lora_sd[key] * ratio
|
||||
|
||||
print(f"dim (rank): {dim}, alpha: {alpha}")
|
||||
if alpha is None:
|
||||
alpha = dim
|
||||
|
||||
return merged_sd, dim, alpha
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
if args.sd_model is not None:
|
||||
print(f"loading SD model: {args.sd_model}")
|
||||
|
||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
|
||||
|
||||
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving SD model to: {args.save_to}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
|
||||
args.sd_model, 0, 0, save_dtype, vae)
|
||||
else:
|
||||
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--v2", action='store_true',
|
||||
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
parser.add_argument("--precision", type=str, default="float",
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||
parser.add_argument("--sd_model", type=str, default=None,
|
||||
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--models", type=str, nargs='*',
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
@@ -5,37 +5,40 @@
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
if model_util.is_safetensors(file_name):
|
||||
sd = load_file(file_name)
|
||||
with safe_open(file_name, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
metadata = None
|
||||
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
return sd
|
||||
|
||||
return sd, metadata
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
if model_util.is_safetensors(file_name):
|
||||
save_file(model, file_name, metadata)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
|
||||
def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||
print("Loading Model...")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
|
||||
@@ -55,7 +58,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||
scale = network_alpha/network_dim
|
||||
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
||||
|
||||
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}")
|
||||
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
|
||||
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
@@ -84,7 +87,7 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||
lora_down_weight = lora_down_weight.squeeze()
|
||||
lora_up_weight = lora_up_weight.squeeze()
|
||||
|
||||
if args.device:
|
||||
if device:
|
||||
org_device = lora_up_weight.device
|
||||
lora_up_weight = lora_up_weight.to(args.device)
|
||||
lora_down_weight = lora_down_weight.to(args.device)
|
||||
@@ -125,7 +128,8 @@ def resize_lora_model(model, new_rank, merge_dtype, save_dtype):
|
||||
weights_loaded = False
|
||||
|
||||
print("resizing complete")
|
||||
return o_lora_sd
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
|
||||
|
||||
def resize(args):
|
||||
|
||||
@@ -143,10 +147,27 @@ def resize(args):
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype)
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
|
||||
print("resizing rank...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
|
||||
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
metadata["sshs_legacy_hash"] = legacy_hash
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
164
networks/svd_merge_lora.py
Normal file
164
networks/svd_merge_lora.py
Normal file
@@ -0,0 +1,164 @@
|
||||
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from tqdm import tqdm
|
||||
import library.model_util as model_util
|
||||
import lora
|
||||
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
sd = load_file(file_name)
|
||||
else:
|
||||
sd = torch.load(file_name, map_location='cpu')
|
||||
for key in list(sd.keys()):
|
||||
if type(sd[key]) == torch.Tensor:
|
||||
sd[key] = sd[key].to(dtype)
|
||||
return sd
|
||||
|
||||
|
||||
def save_to_file(file_name, model, state_dict, dtype):
|
||||
if dtype is not None:
|
||||
for key in list(state_dict.keys()):
|
||||
if type(state_dict[key]) == torch.Tensor:
|
||||
state_dict[key] = state_dict[key].to(dtype)
|
||||
|
||||
if os.path.splitext(file_name)[1] == '.safetensors':
|
||||
save_file(model, file_name)
|
||||
else:
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
||||
merged_sd = {}
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
lora_sd = load_state_dict(model, merge_dtype)
|
||||
|
||||
# merge
|
||||
print(f"merging...")
|
||||
for key in tqdm(list(lora_sd.keys())):
|
||||
if 'lora_down' not in key:
|
||||
continue
|
||||
|
||||
lora_module_name = key[:key.rfind(".lora_down")]
|
||||
|
||||
down_weight = lora_sd[key]
|
||||
network_dim = down_weight.size()[0]
|
||||
|
||||
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
|
||||
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
|
||||
|
||||
in_dim = down_weight.size()[1]
|
||||
out_dim = up_weight.size()[0]
|
||||
conv2d = len(down_weight.size()) == 4
|
||||
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
|
||||
|
||||
# make original weight if not exist
|
||||
if lora_module_name not in merged_sd:
|
||||
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
else:
|
||||
weight = merged_sd[lora_module_name]
|
||||
|
||||
# merge to weight
|
||||
if device:
|
||||
up_weight = up_weight.to(device)
|
||||
down_weight = down_weight.to(device)
|
||||
|
||||
# W <- W + U * D
|
||||
scale = (alpha / network_dim)
|
||||
if not conv2d: # linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
|
||||
merged_sd[lora_module_name] = weight
|
||||
|
||||
# extract from merged weights
|
||||
print("extract new lora...")
|
||||
merged_lora_sd = {}
|
||||
with torch.no_grad():
|
||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||
conv2d = (len(mat.size()) == 4)
|
||||
if conv2d:
|
||||
mat = mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
U = U[:, :new_rank]
|
||||
S = S[:new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
Vh = Vh[:new_rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
up_weight = U
|
||||
down_weight = Vh
|
||||
|
||||
if conv2d:
|
||||
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
|
||||
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
|
||||
|
||||
return merged_lora_sd
|
||||
|
||||
|
||||
def merge(args):
|
||||
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
|
||||
|
||||
def str_to_dtype(p):
|
||||
if p == 'float':
|
||||
return torch.float
|
||||
if p == 'fp16':
|
||||
return torch.float16
|
||||
if p == 'bf16':
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
merge_dtype = str_to_dtype(args.precision)
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--save_precision", type=str, default=None,
|
||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
|
||||
parser.add_argument("--precision", type=str, default="float",
|
||||
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度(floatを推奨)")
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--models", type=str, nargs='*',
|
||||
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
|
||||
parser.add_argument("--ratios", type=float, nargs='*',
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument("--new_rank", type=int, default=4,
|
||||
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
|
||||
args = parser.parse_args()
|
||||
merge(args)
|
||||
@@ -1,23 +1,24 @@
|
||||
accelerate==0.15.0
|
||||
transformers==4.26.0
|
||||
ftfy
|
||||
albumentations
|
||||
opencv-python
|
||||
einops
|
||||
ftfy==6.1.1
|
||||
albumentations==1.3.0
|
||||
opencv-python==4.7.0.68
|
||||
einops==0.6.0
|
||||
diffusers[torch]==0.10.2
|
||||
pytorch_lightning
|
||||
pytorch-lightning==1.9.0
|
||||
bitsandbytes==0.35.0
|
||||
tensorboard
|
||||
tensorboard==2.10.1
|
||||
safetensors==0.2.6
|
||||
gradio
|
||||
altair
|
||||
easygui
|
||||
gradio==3.16.2
|
||||
altair==4.2.2
|
||||
easygui==0.98.3
|
||||
# for BLIP captioning
|
||||
requests
|
||||
timm==0.4.12
|
||||
fairscale==0.4.4
|
||||
requests==2.28.2
|
||||
timm==0.6.12
|
||||
fairscale==0.4.13
|
||||
# for WD14 captioning
|
||||
tensorflow<2.11
|
||||
huggingface-hub
|
||||
# tensorflow<2.11
|
||||
tensorflow==2.10.1
|
||||
huggingface-hub==0.12.0
|
||||
# for kohya_ss library
|
||||
.
|
||||
122
tools/resize_images_to_resolution.py
Normal file
122
tools/resize_images_to_resolution.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import glob
|
||||
import os
|
||||
import cv2
|
||||
import argparse
|
||||
import shutil
|
||||
import math
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
|
||||
# Split the max_resolution string by "," and strip any whitespaces
|
||||
max_resolutions = [res.strip() for res in max_resolution.split(',')]
|
||||
|
||||
# # Calculate max_pixels from max_resolution string
|
||||
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
||||
|
||||
# Create destination folder if it does not exist
|
||||
if not os.path.exists(dst_img_folder):
|
||||
os.makedirs(dst_img_folder)
|
||||
|
||||
# Select interpolation method
|
||||
if interpolation == 'lanczos4':
|
||||
cv2_interpolation = cv2.INTER_LANCZOS4
|
||||
elif interpolation == 'cubic':
|
||||
cv2_interpolation = cv2.INTER_CUBIC
|
||||
else:
|
||||
cv2_interpolation = cv2.INTER_AREA
|
||||
|
||||
# Iterate through all files in src_img_folder
|
||||
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
|
||||
for filename in os.listdir(src_img_folder):
|
||||
# Check if the image is png, jpg or webp etc...
|
||||
if not filename.endswith(img_exts):
|
||||
# Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
|
||||
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
|
||||
continue
|
||||
|
||||
# Load image
|
||||
# img = cv2.imread(os.path.join(src_img_folder, filename))
|
||||
image = Image.open(os.path.join(src_img_folder, filename))
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
img = np.array(image, np.uint8)
|
||||
|
||||
base, _ = os.path.splitext(filename)
|
||||
for max_resolution in max_resolutions:
|
||||
# Calculate max_pixels from max_resolution string
|
||||
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
|
||||
|
||||
# Calculate current number of pixels
|
||||
current_pixels = img.shape[0] * img.shape[1]
|
||||
|
||||
# Check if the image needs resizing
|
||||
if current_pixels > max_pixels:
|
||||
# Calculate scaling factor
|
||||
scale_factor = max_pixels / current_pixels
|
||||
|
||||
# Calculate new dimensions
|
||||
new_height = int(img.shape[0] * math.sqrt(scale_factor))
|
||||
new_width = int(img.shape[1] * math.sqrt(scale_factor))
|
||||
|
||||
# Resize image
|
||||
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
|
||||
else:
|
||||
new_height, new_width = img.shape[0:2]
|
||||
|
||||
# Calculate the new height and width that are divisible by divisible_by (with/without resizing)
|
||||
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
|
||||
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
|
||||
|
||||
# Center crop the image to the calculated dimensions
|
||||
y = int((img.shape[0] - new_height) / 2)
|
||||
x = int((img.shape[1] - new_width) / 2)
|
||||
img = img[y:y + new_height, x:x + new_width]
|
||||
|
||||
# Split filename into base and extension
|
||||
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
|
||||
|
||||
# Save resized image in dst_img_folder
|
||||
# cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
|
||||
image = Image.fromarray(img)
|
||||
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
|
||||
|
||||
proc = "Resized" if current_pixels > max_pixels else "Saved"
|
||||
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
|
||||
|
||||
# If other files with same basename, copy them with resolution suffix
|
||||
if copy_associated_files:
|
||||
asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
|
||||
for asoc_file in asoc_files:
|
||||
ext = os.path.splitext(asoc_file)[1]
|
||||
if ext in img_exts:
|
||||
continue
|
||||
for max_resolution in max_resolutions:
|
||||
new_asoc_file = base + '+' + max_resolution + ext
|
||||
print(f"Copy {asoc_file} as {new_asoc_file}")
|
||||
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
|
||||
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
|
||||
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
|
||||
parser.add_argument('--max_resolution', type=str,
|
||||
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
|
||||
parser.add_argument('--divisible_by', type=int,
|
||||
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
|
||||
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
|
||||
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
|
||||
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
|
||||
parser.add_argument('--copy_associated_files', action='store_true',
|
||||
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
|
||||
|
||||
args = parser.parse_args()
|
||||
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
|
||||
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -38,8 +38,13 @@ def train(args):
|
||||
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||
|
||||
if args.no_token_padding:
|
||||
train_dataset.disable_token_padding()
|
||||
|
||||
# 学習データのdropout率を設定する
|
||||
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
||||
|
||||
train_dataset.make_buckets()
|
||||
|
||||
if args.debug_dataset:
|
||||
@@ -203,6 +208,7 @@ def train(args):
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset.set_current_epoch(epoch + 1)
|
||||
|
||||
# 指定したステップ数までText Encoderを学習する:epoch最初の状態
|
||||
unet.train()
|
||||
@@ -327,7 +333,7 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, False)
|
||||
train_util.add_dataset_arguments(parser, True, False, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
train_util.add_sd_saving_arguments(parser)
|
||||
|
||||
|
||||
@@ -133,6 +133,10 @@ def train(args):
|
||||
args.bucket_reso_steps, args.bucket_no_upscale,
|
||||
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||
args.dataset_repeats, args.debug_dataset)
|
||||
|
||||
# 学習データのdropout率を設定する
|
||||
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
|
||||
|
||||
train_dataset.make_buckets()
|
||||
|
||||
if args.debug_dataset:
|
||||
@@ -387,6 +391,8 @@ def train(args):
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset.set_current_epoch(epoch + 1)
|
||||
|
||||
metadata["ss_epoch"] = str(epoch+1)
|
||||
|
||||
network.on_epoch_start(text_encoder, unet)
|
||||
@@ -521,7 +527,7 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
|
||||
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||
|
||||
@@ -55,7 +55,7 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
--network_module=networks.lora
|
||||
```
|
||||
|
||||
--output_dirオプションで指定したディレクトリに、LoRAのモデルが保存されます。
|
||||
--output_dirオプションで指定したフォルダに、LoRAのモデルが保存されます。
|
||||
|
||||
その他、以下のオプションが指定できます。
|
||||
|
||||
@@ -178,6 +178,38 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA
|
||||
- --save_precision
|
||||
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
|
||||
|
||||
## 画像リサイズスクリプト
|
||||
|
||||
(のちほどドキュメントを整理しますがとりあえずここに説明を書いておきます。)
|
||||
|
||||
Aspect Ratio Bucketingの機能拡張で、小さな画像については拡大しないでそのまま教師データとすることが可能になりました。元の教師画像を縮小した画像を、教師データに加えると精度が向上したという報告とともに前処理用のスクリプトをいただきましたので整備して追加しました。bmaltais氏に感謝します。
|
||||
|
||||
### スクリプトの実行方法
|
||||
|
||||
以下のように指定してください。元の画像そのまま、およびリサイズ後の画像が変換先フォルダに保存されます。リサイズ後の画像には、ファイル名に ``+512x512`` のようにリサイズ先の解像度が付け加えられます(画像サイズとは異なります)。リサイズ先の解像度より小さい画像は拡大されることはありません。
|
||||
|
||||
```
|
||||
python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
|
||||
--copy_associated_files 元画像フォルダ 変換先フォルダ
|
||||
```
|
||||
|
||||
元画像フォルダ内の画像ファイルが、指定した解像度(複数指定可)と同じ面積になるようにリサイズされ、変換先フォルダに保存されます。画像以外のファイルはそのままコピーされます。
|
||||
|
||||
``--max_resolution`` オプションにリサイズ先のサイズを例のように指定してください。面積がそのサイズになるようにリサイズします。複数指定すると、それぞれの解像度でリサイズされます。``512x512,384x384,256x256``なら、変換先フォルダの画像は、元サイズとリサイズ後サイズ×3の計4枚になります。
|
||||
|
||||
``--save_as_png`` オプションを指定するとpng形式で保存します。省略するとjpeg形式(quality=100)で保存されます。
|
||||
|
||||
``--copy_associated_files`` オプションを指定すると、拡張子を除き画像と同じファイル名(たとえばキャプションなど)のファイルが、リサイズ後の画像のファイル名と同じ名前でコピーされます。
|
||||
|
||||
|
||||
### その他のオプション
|
||||
|
||||
- divisible_by
|
||||
- リサイズ後の画像のサイズ(縦、横のそれぞれ)がこの値で割り切れるように、画像中心を切り出します。
|
||||
- interpolation
|
||||
- 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
|
||||
|
||||
|
||||
## 追加情報
|
||||
|
||||
### cloneofsimo氏のリポジトリとの違い
|
||||
|
||||
@@ -98,12 +98,12 @@ def train(args):
|
||||
|
||||
# Convert the init_word to token_id
|
||||
if args.init_word is not None:
|
||||
init_token_id = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||
assert len(
|
||||
init_token_id) == 1, f"init word {args.init_word} is not converted to single token / 初期化単語が二つ以上のトークンに変換されます。別の単語を使ってください"
|
||||
init_token_id = init_token_id[0]
|
||||
init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||
if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
|
||||
print(
|
||||
f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}")
|
||||
else:
|
||||
init_token_id = None
|
||||
init_token_ids = None
|
||||
|
||||
# add new word to tokenizer, count is num_vectors_per_token
|
||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||
@@ -120,9 +120,9 @@ def train(args):
|
||||
|
||||
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||
if init_token_id is not None:
|
||||
for token_id in token_ids:
|
||||
token_embeds[token_id] = token_embeds[init_token_id]
|
||||
if init_token_ids is not None:
|
||||
for i, token_id in enumerate(token_ids):
|
||||
token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
|
||||
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||
|
||||
# load weights
|
||||
@@ -235,7 +235,7 @@ def train(args):
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||
print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
# Freeze all parameters except for the token embeddings in text encoder
|
||||
@@ -296,6 +296,7 @@ def train(args):
|
||||
|
||||
for epoch in range(num_train_epochs):
|
||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||
train_dataset.set_current_epoch(epoch + 1)
|
||||
|
||||
text_encoder.train()
|
||||
|
||||
@@ -383,8 +384,8 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||
d = updated_embs - bef_epo_embs
|
||||
print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
||||
# d = updated_embs - bef_epo_embs
|
||||
# print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||
@@ -478,7 +479,7 @@ if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
train_util.add_sd_models_arguments(parser)
|
||||
train_util.add_dataset_arguments(parser, True, True)
|
||||
train_util.add_dataset_arguments(parser, True, True, False)
|
||||
train_util.add_training_arguments(parser, True)
|
||||
|
||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||
@@ -491,7 +492,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--token_string", type=str, default=None,
|
||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
|
||||
parser.add_argument("--init_word", type=str, default=None,
|
||||
help="word to initialize vector / ベクトルを初期化に使用する単語、tokenizerで一語になること")
|
||||
help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||
parser.add_argument("--use_object_template", action='store_true',
|
||||
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
|
||||
parser.add_argument("--use_style_template", action='store_true',
|
||||
|
||||
Reference in New Issue
Block a user