Merge branch 'kohya-ss:main' into support-multi-gpu

This commit is contained in:
Isotr0py
2023-02-12 15:06:46 +08:00
committed by GitHub
18 changed files with 1067 additions and 238 deletions

21
.github/workflows/typos.yml vendored Normal file
View File

@@ -0,0 +1,21 @@
---
# yamllint disable rule:line-length
name: Typos
on: # yamllint disable-line rule:truthy
push:
pull_request:
types:
- opened
- synchronize
- reopened
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: typos-action
uses: crate-ci/typos@v1.13.10

View File

@@ -116,7 +116,7 @@ accelerate configの質問には以下のように答えてください。bf1
cd sd-scripts cd sd-scripts
git pull git pull
.\venv\Scripts\activate .\venv\Scripts\activate
pip install --upgrade -r requirements.txt pip install --use-pep517 --upgrade -r requirements.txt
``` ```
コマンドが成功すれば新しいバージョンが使用できます。 コマンドが成功すれば新しいバージョンが使用できます。

139
README.md
View File

@@ -1,51 +1,7 @@
This repository contains training, generation and utility scripts for Stable Diffusion. This repository contains training, generation and utility scripts for Stable Diffusion.
## Updates [__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
Note: The LoRA models for SD 2.x is not supported too in Web UI.
- 6 Feb. 2023, 2023/2/6
- ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``.
- ``--bucket_reso_steps`` takes the steps for buckets in aspect ratio bucketing. Default is 64, same as before.
- Any value greater than or equal to 1 can be specified; 64 is highly recommended and a value divisible by 8 is recommended.
- If less than 64 is specified, padding will occur within U-Net. The result is unknown.
- If you specify a value that is not divisible by 8, it will be truncated to divisible by 8 inside VAE, because the size of the latent is 1/8 of the image size.
- If ``--bucket_no_upscale`` option is specified, images smaller than the bucket size will be processed without upscaling.
- Internally, a bucket smaller than the image size is created (for example, if the image is 300x300 and ``bucket_reso_steps=64``, the bucket is 256x256). The image will be trimmed.
- Implementation of [#130](https://github.com/kohya-ss/sd-scripts/issues/130).
- Images with an area larger than the maximum size specified by ``--resolution`` are downsampled to the max bucket size.
- Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
- ``--random_crop`` now also works with buckets enabled.
- Instead of always cropping the center of the image, the image is shifted left, right, up, and down to be used as the training data. This is expected to train to the edges of the image.
- Implementation of discussion [#34](https://github.com/kohya-ss/sd-scripts/discussions/34).
- ``--bucket_reso_steps``および``--bucket_no_upscale``オプションを、学習スクリプトおよび``prepare_buckets_latents.py``に追加しました。
- ``--bucket_reso_steps``オプションでは、bucketの解像度の単位を指定できます。デフォルトは64で、今までと同じ動作です。
- 1以上の任意の値を指定できます。基本的には64を推奨します。64以外の値では、8で割り切れる値を推奨します。
- 64未満を指定するとU-Netの内部でpaddingが発生します。どのような結果になるかは未知数です。
- 8で割り切れない値を指定すると余りはVAE内部で切り捨てられます。
- ``--bucket_no_upscale``オプションを指定すると、bucketサイズよりも小さい画像は拡大せずそのまま処理します。
- 内部的には画像サイズ以下のサイズのbucketを作成しますたとえば画像が300x300で``bucket_reso_steps=64``の場合、256x256のbucket。余りは都度trimmingされます。
- [#130](https://github.com/kohya-ss/sd-scripts/issues/130) を実装したものです。
- ``--resolution``で指定した最大サイズよりも面積が大きい画像は、最大サイズと同じ面積になるようアスペクト比を維持したまま縮小され、そのサイズを元にbucketが作られます。
- これらのオプションによりbucketが細分化され、ひとつのバッチ内に同一画像が重複して存在することが増えたため、バッチサイズを``そのbucketの画像種類数``までに制限する機能を追加しました。
- たとえば繰り返し回数10で、あるbucketに1枚しか画像がなく、バッチサイズが10以上のとき、今まではepoch内で、同一画像を10枚含むバッチが1回だけ使用されていました。
- 機能追加後はepoch内にサイズ1のバッチが10回、使用されます。
- ``--random_crop``がbucketを有効にした場合にも機能するようになりました。
- 常に画像の中央を切り取るのではなく、左右、上下にずらして教師データにします。これにより画像端まで学習されることが期待されます。
- discussionの[#34](https://github.com/kohya-ss/sd-scripts/discussions/34)を実装したものです。
Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。
SD2.x用のLoRAモデルはサポートされないようです。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
##
[日本語版README](./README-ja.md) [日本語版README](./README-ja.md)
@@ -54,10 +10,13 @@ For easier use (GUI and PowerShell scripts etc...), please visit [the repository
This repository contains the scripts for: This repository contains the scripts for:
* DreamBooth training, including U-Net and Text Encoder * DreamBooth training, including U-Net and Text Encoder
* fine-tuning (native training), including U-Net and Text Encoder * Fine-tuning (native training), including U-Net and Text Encoder
* LoRA training * LoRA training
* image generation * Texutl Inversion training
* model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers) * Image generation
* Model conversion (supports 1.x and 2.x, Stable Diffision ckpt/safetensors and Diffusers)
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ (SD 1.x based only) Thank you for great work!!!
## About requirements.txt ## About requirements.txt
@@ -144,7 +103,7 @@ When a new release comes out you can upgrade your repo with the following comman
cd sd-scripts cd sd-scripts
git pull git pull
.\venv\Scripts\activate .\venv\Scripts\activate
pip install --upgrade -r requirements.txt pip install --use-pep517 --upgrade -r requirements.txt
``` ```
Once the commands have completed successfully you should be ready to use the new version. Once the commands have completed successfully you should be ready to use the new version.
@@ -162,3 +121,83 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT [bitsandbytes](https://github.com/TimDettmers/bitsandbytes): MIT
[BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause [BLIP](https://github.com/salesforce/BLIP): BSD-3-Clause
## Change History
- 11 Feb. 2023, 2023/2/11:
- ``lora_interrogator.py`` is added in ``networks`` folder. See ``python networks\lora_interrogator.py -h`` for usage.
- For LoRAs where the activation word is unknown, this script compares the output of Text Encoder after applying LoRA to that of unapplied to find out which token is affected by LoRA. Hopefully you can figure out the activation word. LoRA trained with captions does not seem to be able to interrogate.
- Batch size can be large (like 64 or 128).
- ``train_textual_inversion.py`` now supports multiple init words.
- Following feature is reverted to be the same as before. Sorry for confusion:
> Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
- ``lora_interrogator.py`` を ``network``フォルダに追加しました。使用法は ``python networks\lora_interrogator.py -h`` でご確認ください。
- このスクリプトは、起動promptがわからないLoRAについて、LoRA適用前後のText Encoderの出力を比較することで、どのtokenの出力が変化しているかを調べます。運が良ければ起動用の単語が分かります。キャプション付きで学習されたLoRAは影響が広範囲に及ぶため、調査は難しいようです。
- バッチサイズはわりと大きくできます64や128など
- ``train_textual_inversion.py`` で複数のinit_word指定が可能になりました。
- 次の機能を削除し元に戻しました。混乱を招き申し訳ありません。
> これらのオプションによりbucketが細分化され、ひとつのバッチ内に同一画像が重複して存在することが増えたため、バッチサイズを``そのbucketの画像種類数``までに制限する機能を追加しました。
- 10 Feb. 2023, 2023/2/10:
- Updated ``requirements.txt`` to prevent upgrading with pip taking a long time or failure to upgrade.
- ``resize_lora.py`` keeps the metadata of the model. ``dimension is resized from ...`` is added to the top of ``ss_training_comment``.
- ``merge_lora.py`` supports models with different ``alpha``s. If there is a problem, old version is ``merge_lora_old.py``.
- ``svd_merge_lora.py`` is added. This script merges LoRA models with any rank (dim) and alpha, and approximate a new LoRA with svd for a specified rank (dim).
- Note: merging scripts erase the metadata currently.
- ``resize_images_to_resolution.py`` supports multibyte characters in filenames.
- pipでの更新が長時間掛かったり、更新に失敗したりするのを防ぐため、``requirements.txt``を更新しました。
- ``resize_lora.py``がメタデータを保持するようになりました。 ``dimension is resized from ...`` という文字列が ``ss_training_comment`` の先頭に追加されます。
- ``merge_lora.py``がalphaが異なるモデルをサポートしました。 何か問題がありましたら旧バージョン ``merge_lora_old.py`` をお使いください。
- ``svd_merge_lora.py`` を追加しました。 複数の任意のdim (rank)、alphaのLoRAモデルをマージし、svdで任意dim(rank)のLoRAで近似します。
- 注:マージ系のスクリプトは現時点ではメタデータを消去しますのでご注意ください。
- ``resize_images_to_resolution.py``が日本語ファイル名をサポートしました。
- 9 Feb. 2023, 2023/2/9:
- Caption dropout is supported in ``train_db.py``, ``fine_tune.py`` and ``train_network.py``. Thanks to forestsource!
- ``--caption_dropout_rate`` option specifies the dropout rate for captions (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the image is trained with the empty caption. Default is 0 (no dropout).
- ``--caption_dropout_every_n_epochs`` option specifies how many epochs to drop captions. If ``3`` is specified, in epoch 3, 6, 9 ..., images are trained with all captions empty. Default is None (no dropout).
- ``--caption_tag_dropout_rate`` option specified the dropout rate for tags (comma separated tokens) (0~1.0, 0.1 means 10% chance for dropout). If dropout occurs, the tag is removed from the caption. If ``--keep_tokens`` option is set, these tokens (tags) are not dropped. Default is 0 (no droupout).
- The bulk image downsampling script is added. Documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) (in Jpanaese). Thanks to bmaltais!
- Typo check is added. Thanks to shirayu!
- キャプションのドロップアウトを``train_db.py``、``fine_tune.py``、``train_network.py``の各スクリプトに追加しました。forestsource氏に感謝します。
- ``--caption_dropout_rate``オプションでキャプションのドロップアウト率を指定します0~1.0、 0.1を指定すると10%の確率でドロップアウト)。ドロップアウトされた場合、画像は空のキャプションで学習されます。デフォルトは 0 (ドロップアウトなし)です。
- ``--caption_dropout_every_n_epochs`` オプションで何エポックごとにキャプションを完全にドロップアウトするか指定します。たとえば``3``を指定すると、エポック3、6、9……で、すべての画像がキャプションなしで学習されます。デフォルトは None (ドロップアウトなし)です。
- ``--caption_tag_dropout_rate`` オプションで各タグカンマ区切りの各部分のドロップアウト率を指定します0~1.0、 0.1を指定すると10%の確率でドロップアウト)。ドロップアウトが起きるとそのタグはそのときだけキャプションから取り除かれて学習されます。``--keep_tokens`` オプションを指定していると、シャッフルされない部分のタグはドロップアウトされません。デフォルトは 0 (ドロップアウトなし)です。
- 画像の一括縮小スクリプトを追加しました。ドキュメントは [こちら](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E7%94%BB%E5%83%8F%E3%83%AA%E3%82%B5%E3%82%A4%E3%82%BA%E3%82%B9%E3%82%AF%E3%83%AA%E3%83%97%E3%83%88) です。bmaltais氏に感謝します。
- 誤字チェッカが追加されました。shirayu氏に感謝します。
- 6 Feb. 2023, 2023/2/6
- ``--bucket_reso_steps`` and ``--bucket_no_upscale`` options are added to training scripts (fine tuning, DreamBooth, LoRA and Textual Inversion) and ``prepare_buckets_latents.py``.
- ``--bucket_reso_steps`` takes the steps for buckets in aspect ratio bucketing. Default is 64, same as before.
- Any value greater than or equal to 1 can be specified; 64 is highly recommended and a value divisible by 8 is recommended.
- If less than 64 is specified, padding will occur within U-Net. The result is unknown.
- If you specify a value that is not divisible by 8, it will be truncated to divisible by 8 inside VAE, because the size of the latent is 1/8 of the image size.
- If ``--bucket_no_upscale`` option is specified, images smaller than the bucket size will be processed without upscaling.
- Internally, a bucket smaller than the image size is created (for example, if the image is 300x300 and ``bucket_reso_steps=64``, the bucket is 256x256). The image will be trimmed.
- Implementation of [#130](https://github.com/kohya-ss/sd-scripts/issues/130).
- Images with an area larger than the maximum size specified by ``--resolution`` are downsampled to the max bucket size.
- Now the number of data in each batch is limited to the number of actual images (not duplicated). Because a certain bucket may contain smaller number of actual images, so the batch may contain same (duplicated) images.
- ``--random_crop`` now also works with buckets enabled.
- Instead of always cropping the center of the image, the image is shifted left, right, up, and down to be used as the training data. This is expected to train to the edges of the image.
- Implementation of discussion [#34](https://github.com/kohya-ss/sd-scripts/discussions/34).
- ``--bucket_reso_steps``および``--bucket_no_upscale``オプションを、学習スクリプトおよび``prepare_buckets_latents.py``に追加しました。
- ``--bucket_reso_steps``オプションでは、bucketの解像度の単位を指定できます。デフォルトは64で、今までと同じ動作です。
- 1以上の任意の値を指定できます。基本的には64を推奨します。64以外の値では、8で割り切れる値を推奨します。
- 64未満を指定するとU-Netの内部でpaddingが発生します。どのような結果になるかは未知数です。
- 8で割り切れない値を指定すると余りはVAE内部で切り捨てられます。
- ``--bucket_no_upscale``オプションを指定すると、bucketサイズよりも小さい画像は拡大せずそのまま処理します。
- 内部的には画像サイズ以下のサイズのbucketを作成しますたとえば画像が300x300で``bucket_reso_steps=64``の場合、256x256のbucket。余りは都度trimmingされます。
- [#130](https://github.com/kohya-ss/sd-scripts/issues/130) を実装したものです。
- ``--resolution``で指定した最大サイズよりも面積が大きい画像は、最大サイズと同じ面積になるようアスペクト比を維持したまま縮小され、そのサイズを元にbucketが作られます。
- これらのオプションによりbucketが細分化され、ひとつのバッチ内に同一画像が重複して存在することが増えたため、バッチサイズを``そのbucketの画像種類数``までに制限する機能を追加しました。
- たとえば繰り返し回数10で、あるbucketに1枚しか画像がなく、バッチサイズが10以上のとき、今まではepoch内で、同一画像を10枚含むバッチが1回だけ使用されていました。
- 機能追加後はepoch内にサイズ1のバッチが10回、使用されます。
- ``--random_crop``がbucketを有効にした場合にも機能するようになりました。
- 常に画像の中央を切り取るのではなく、左右、上下にずらして教師データにします。これにより画像端まで学習されることが期待されます。
- discussionの[#34](https://github.com/kohya-ss/sd-scripts/discussions/34)を実装したものです。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。

15
_typos.toml Normal file
View File

@@ -0,0 +1,15 @@
# Files for typos
# Instruction: https://github.com/marketplace/actions/typos-action#getting-started
[default.extend-identifiers]
[default.extend-words]
NIN="NIN"
parms="parms"
nin="nin"
extention="extention" # Intentionally left
nd="nd"
[files]
extend-exclude = ["_typos.toml"]

View File

@@ -36,6 +36,10 @@ def train(args):
args.bucket_reso_steps, args.bucket_no_upscale, args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset) args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets() train_dataset.make_buckets()
if args.debug_dataset: if args.debug_dataset:
@@ -226,6 +230,8 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
for m in training_models: for m in training_models:
m.train() m.train()
@@ -332,7 +338,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, False, True) train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False) train_util.add_training_arguments(parser, False)
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)

View File

@@ -113,7 +113,7 @@ class BucketManager():
# 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく # 規定サイズから選ぶ場合の解像度、aspect ratioの情報を格納しておく
self.predefined_resos = resos.copy() self.predefined_resos = resos.copy()
self.predefined_resos_set = set(resos) self.predefined_resos_set = set(resos)
self.predifined_aspect_ratios = np.array([w / h for w, h in resos]) self.predefined_aspect_ratios = np.array([w / h for w, h in resos])
def add_if_new_reso(self, reso): def add_if_new_reso(self, reso):
if reso not in self.reso_to_id: if reso not in self.reso_to_id:
@@ -135,7 +135,7 @@ class BucketManager():
if reso in self.predefined_resos_set: if reso in self.predefined_resos_set:
pass pass
else: else:
ar_errors = self.predifined_aspect_ratios - aspect_ratio ar_errors = self.predefined_aspect_ratios - aspect_ratio
predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの predefined_bucket_id = np.abs(ar_errors).argmin() # 当該解像度以外でaspect ratio errorが最も少ないもの
reso = self.predefined_resos[predefined_bucket_id] reso = self.predefined_resos[predefined_bucket_id]
@@ -223,6 +223,11 @@ class BaseDataset(torch.utils.data.Dataset):
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
self.current_epoch: int = 0 # インスタンスがepochごとに新しく作られるようなので外側から渡さないとダメ
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
self.tag_dropout_rate: float = 0
# augmentation # augmentation
flip_p = 0.5 if flip_aug else 0.0 flip_p = 0.5 if flip_aug else 0.0
if color_aug: if color_aug:
@@ -247,6 +252,15 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements = {} self.replacements = {}
def set_current_epoch(self, epoch):
self.current_epoch = epoch
def set_caption_dropout(self, dropout_rate, dropout_every_n_epochs, tag_dropout_rate):
# コンストラクタで渡さないのはTextual Inversionで意識したくないからということにしておく
self.dropout_rate = dropout_rate
self.dropout_every_n_epochs = dropout_every_n_epochs
self.tag_dropout_rate = tag_dropout_rate
def set_tag_frequency(self, dir_name, captions): def set_tag_frequency(self, dir_name, captions):
frequency_for_dir = self.tag_frequency.get(dir_name, {}) frequency_for_dir = self.tag_frequency.get(dir_name, {})
self.tag_frequency[dir_name] = frequency_for_dir self.tag_frequency[dir_name] = frequency_for_dir
@@ -264,27 +278,52 @@ class BaseDataset(torch.utils.data.Dataset):
self.replacements[str_from] = str_to self.replacements[str_from] = str_to
def process_caption(self, caption): def process_caption(self, caption):
if self.shuffle_caption: # dropoutの決定tag dropがこのメソッド内にあるのでここで行うのが良い
tokens = caption.strip().split(",") is_drop_out = self.dropout_rate > 0 and random.random() < self.dropout_rate
if self.shuffle_keep_tokens is None: is_drop_out = is_drop_out or self.dropout_every_n_epochs and self.current_epoch % self.dropout_every_n_epochs == 0
random.shuffle(tokens)
else:
if len(tokens) > self.shuffle_keep_tokens:
keep_tokens = tokens[:self.shuffle_keep_tokens]
tokens = tokens[self.shuffle_keep_tokens:]
random.shuffle(tokens)
tokens = keep_tokens + tokens
caption = ",".join(tokens).strip()
for str_from, str_to in self.replacements.items(): if is_drop_out:
if str_from == "": caption = ""
# replace all else:
if type(str_to) == list: if self.shuffle_caption or self.tag_dropout_rate > 0:
caption = random.choice(str_to) def dropout_tags(tokens):
if self.tag_dropout_rate <= 0:
return tokens
l = []
for token in tokens:
if random.random() >= self.tag_dropout_rate:
l.append(token)
return l
tokens = [t.strip() for t in caption.strip().split(",")]
if self.shuffle_keep_tokens is None:
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
else: else:
caption = str_to if len(tokens) > self.shuffle_keep_tokens:
else: keep_tokens = tokens[:self.shuffle_keep_tokens]
caption = caption.replace(str_from, str_to) tokens = tokens[self.shuffle_keep_tokens:]
if self.shuffle_caption:
random.shuffle(tokens)
tokens = dropout_tags(tokens)
tokens = keep_tokens + tokens
caption = ", ".join(tokens)
# textual inversion対応
for str_from, str_to in self.replacements.items():
if str_from == "":
# replace all
if type(str_to) == list:
caption = random.choice(str_to)
else:
caption = str_to
else:
caption = caption.replace(str_from, str_to)
return caption return caption
@@ -393,17 +432,25 @@ class BaseDataset(torch.utils.data.Dataset):
# データ参照用indexを作る。このindexはdatasetのshuffleに用いられる # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる
self.buckets_indices: List(BucketBatchIndex) = [] self.buckets_indices: List(BucketBatchIndex) = []
for bucket_index, bucket in enumerate(self.bucket_manager.buckets): for bucket_index, bucket in enumerate(self.bucket_manager.buckets):
# bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは batch_count = int(math.ceil(len(bucket) / self.batch_size))
# ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
# そのためバッチサイズを画像種類までに制限する
# ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない
# TODO 正則化画像をepochまたがりで利用する仕組み
num_of_image_types = len(set(bucket))
bucket_batch_size = min(self.batch_size, num_of_image_types)
batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
# print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
for batch_index in range(batch_count): for batch_index in range(batch_count):
self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index)) self.buckets_indices.append(BucketBatchIndex(bucket_index, self.batch_size, batch_index))
# ↓以下はbucketごとのbatch件数があまりにも増えて混乱を招くので元に戻す
#  学習時はステップ数がランダムなので、同一画像が同一batch内にあってもそれほど悪影響はないであろう、と考えられる
#
# # bucketが細分化されることにより、ひとつのbucketに一種類の画像のみというケースが増え、つまりそれは
# # ひとつのbatchが同じ画像で占められることになるので、さすがに良くないであろう
# # そのためバッチサイズを画像種類までに制限する
# # ただそれでも同一画像が同一バッチに含まれる可能性はあるので、繰り返し回数が少ないほうがshuffleの品質は良くなることは間違いない
# # TO DO 正則化画像をepochまたがりで利用する仕組み
# num_of_image_types = len(set(bucket))
# bucket_batch_size = min(self.batch_size, num_of_image_types)
# batch_count = int(math.ceil(len(bucket) / bucket_batch_size))
# # print(bucket_index, num_of_image_types, bucket_batch_size, batch_count)
# for batch_index in range(batch_count):
# self.buckets_indices.append(BucketBatchIndex(bucket_index, bucket_batch_size, batch_index))
# ↑ここまで
self.shuffle_buckets() self.shuffle_buckets()
self._length = len(self.buckets_indices) self._length = len(self.buckets_indices)
@@ -809,6 +856,7 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0 self.num_reg_images = 0
# TODO do not record tag freq when no tag
self.set_tag_frequency(os.path.basename(json_file_name), tags_list) self.set_tag_frequency(os.path.basename(json_file_name), tags_list)
self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)} self.dataset_dirs_info[os.path.basename(json_file_name)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
@@ -907,6 +955,8 @@ class FineTuningDataset(BaseDataset):
def debug_dataset(train_dataset, show_input_ids=False): def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("Escape for exit. / Escキーで中断、終了します") print("Escape for exit. / Escキーで中断、終了します")
train_dataset.set_current_epoch(1)
k = 0 k = 0
for i, example in enumerate(train_dataset): for i, example in enumerate(train_dataset):
if example['latents'] is not None: if example['latents'] is not None:
@@ -1377,7 +1427,7 @@ def verify_training_args(args: argparse.Namespace):
print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool): def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool, support_caption: bool, support_caption_dropout: bool):
# dataset common # dataset common
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument("--shuffle_caption", action="store_true", parser.add_argument("--shuffle_caption", action="store_true",
@@ -1408,6 +1458,16 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
parser.add_argument("--bucket_no_upscale", action="store_true", parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します") help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
if support_caption_dropout:
# Textual Inversion はcaptionのdropoutをsupportしない
# いわゆるtensorのDropoutと紛らわしいのでprefixにcaptionを付けておく every_n_epochsは他と平仄を合わせてdefault Noneに
parser.add_argument("--caption_dropout_rate", type=float, default=0,
help="Rate out dropout caption(0.0~1.0) / captionをdropoutする割合")
parser.add_argument("--caption_dropout_every_n_epochs", type=int, default=None,
help="Dropout all captions every N epochs / captionを指定エポックごとにdropoutする")
parser.add_argument("--caption_tag_dropout_rate", type=float, default=0,
help="Rate out dropout comma separated tokens(0.0~1.0) / カンマ区切りのタグをdropoutする割合")
if support_dreambooth: if support_dreambooth:
# DreamBooth dataset # DreamBooth dataset
parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ") parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")

View File

@@ -5,6 +5,7 @@
import math import math
import os import os
from typing import List
import torch import torch
from library import train_util from library import train_util
@@ -98,7 +99,7 @@ class LoRANetwork(torch.nn.Module):
self.alpha = alpha self.alpha = alpha
# create module instances # create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> List[LoRAModule]:
loras = [] loras = []
for name, module in root_module.named_modules(): for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules: if module.__class__.__name__ in target_replace_modules:

View File

@@ -0,0 +1,122 @@
from tqdm import tqdm
from library import model_util
import argparse
from transformers import CLIPTokenizer
import torch
import library.model_util as model_util
import lora
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # ここからtokenizerだけ使う
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def interrogate(args):
# いろいろ準備する
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
print(f"loading LoRA: {args.model}")
network = lora.create_network_from_weights(1.0, args.model, vae, text_encoder, unet)
# text encoder向けの重みがあるかチェックする本当はlora側でやるのがいい
has_te_weight = False
for key in network.weights_sd.keys():
if 'lora_te' in key:
has_te_weight = True
break
if not has_te_weight:
print("This LoRA does not have modules for Text Encoder, cannot interrogate / このLoRAはText Encoder向けのモジュールがないため調査できません")
return
del vae
print("loading tokenizer")
if args.v2:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
else:
tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH) # , model_max_length=max_token_length + 2)
text_encoder.to(DEVICE)
text_encoder.eval()
unet.to(DEVICE)
unet.eval() # U-Netは呼び出さないので不要だけど
# トークンをひとつひとつ当たっていく
token_id_start = 0
token_id_end = max(tokenizer.all_special_ids)
print(f"interrogate tokens are: {token_id_start} to {token_id_end}")
def get_all_embeddings(text_encoder):
embs = []
with torch.no_grad():
for token_id in tqdm(range(token_id_start, token_id_end + 1, args.batch_size)):
batch = []
for tid in range(token_id, min(token_id_end + 1, token_id + args.batch_size)):
tokens = [tokenizer.bos_token_id, tid, tokenizer.eos_token_id]
# tokens = [tid] # こちらは結果がいまひとつ
batch.append(tokens)
# batch_embs = text_encoder(torch.tensor(batch).to(DEVICE))[0].to("cpu") # bos/eosも含めたほうが差が出るようだ [:, 1]
# clip skip対応
batch = torch.tensor(batch).to(DEVICE)
if args.clip_skip is None:
encoder_hidden_states = text_encoder(batch)[0]
else:
enc_out = text_encoder(batch, output_hidden_states=True, return_dict=True)
encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.to("cpu")
embs.extend(encoder_hidden_states)
return torch.stack(embs)
print("get original text encoder embeddings.")
orig_embs = get_all_embeddings(text_encoder)
network.apply_to(text_encoder, unet, True, len(network.unet_loras) > 0)
network.to(DEVICE)
network.eval()
print("You can ignore warning messages start with '_IncompatibleKeys' (LoRA model does not have alpha because trained by older script) / '_IncompatibleKeys'の警告は無視して構いません以前のスクリプトで学習されたLoRAモデルのためalphaの定義がありません")
print("get text encoder embeddings with lora.")
lora_embs = get_all_embeddings(text_encoder)
# 比べる:とりあえず単純に差分の絶対値で
print("comparing...")
diffs = {}
for i, (orig_emb, lora_emb) in enumerate(zip(orig_embs, tqdm(lora_embs))):
diff = torch.mean(torch.abs(orig_emb - lora_emb))
# diff = torch.mean(torch.cosine_similarity(orig_emb, lora_emb, dim=1)) # うまく検出できない
diff = float(diff.detach().to('cpu').numpy())
diffs[token_id_start + i] = diff
diffs_sorted = sorted(diffs.items(), key=lambda x: -x[1])
# 結果を表示する
print("top 100:")
for i, (token, diff) in enumerate(diffs_sorted[:100]):
# if diff < 1e-6:
# break
string = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens([token]))
print(f"[{i:3d}]: {token:5d} {string:<20s}: {diff:.5f}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file / 読み込むSDのモデル、ckptまたはsafetensors")
parser.add_argument("--model", type=str, default=None,
help="LoRA model to interrogate: ckpt or safetensors file / 調査するLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--batch_size", type=int, default=16,
help="batch size for processing with Text Encoder / Text Encoderで処理するときのバッチサイズ")
parser.add_argument("--clip_skip", type=int, default=None,
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いるnは1以上")
args = parser.parse_args()
interrogate(args)

View File

@@ -1,5 +1,5 @@
import math
import argparse import argparse
import os import os
import torch import torch
@@ -85,43 +85,76 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
weight = weight + ratio * (up_weight @ down_weight) * scale weight = weight + ratio * (up_weight @ down_weight) * scale
else: else:
# conv2d # conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight) module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype): def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {} base_alphas = {} # alpha for merged model
base_dims = {}
alpha = None merged_sd = {}
dim = None
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype) lora_sd = load_state_dict(model, merge_dtype)
# get alpha and dim
alphas = {} # alpha for current model
dims = {} # dims for current model
for key in lora_sd.keys():
if 'alpha' in key:
lora_module_name = key[:key.rfind(".alpha")]
alpha = float(lora_sd[key].detach().numpy())
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
elif "lora_down" in key:
lora_module_name = key[:key.rfind(".lora_down")]
dim = lora_sd[key].size()[0]
dims[lora_module_name] = dim
if lora_module_name not in base_dims:
base_dims[lora_module_name] = dim
for lora_module_name in dims.keys():
if lora_module_name not in alphas:
alpha = dims[lora_module_name]
alphas[lora_module_name] = alpha
if lora_module_name not in base_alphas:
base_alphas[lora_module_name] = alpha
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
# merge
print(f"merging...") print(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if 'alpha' in key: if 'alpha' in key:
if key in merged_sd: continue
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else: lora_module_name = key[:key.rfind(".lora_")]
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key] base_alpha = base_alphas[lora_module_name]
alpha = alphas[lora_module_name]
scale = math.sqrt(alpha / base_alpha) * ratio
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
else: else:
if key in merged_sd: merged_sd[key] = lora_sd[key] * scale
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}") # set alpha to sd
if alpha is None: for lora_module_name, alpha in base_alphas.items():
alpha = dim key = lora_module_name + ".alpha"
merged_sd[key] = torch.tensor(alpha)
return merged_sd, dim, alpha print("merged model")
print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
return merged_sd
def merge(args): def merge(args):
@@ -152,7 +185,7 @@ def merge(args):
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae) args.sd_model, 0, 0, save_dtype, vae)
else: else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype) state_dict = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}") print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)

179
networks/merge_lora_old.py Normal file
View File

@@ -0,0 +1,179 @@
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
import library.model_util as model_util
import lora
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
text_encoder.to(merge_dtype)
unet.to(merge_dtype)
# create module map
name_to_module = {}
for i, root_module in enumerate([text_encoder, unet]):
if i == 0:
prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER
target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
else:
prefix = lora.LoRANetwork.LORA_PREFIX_UNET
target_replace_modules = lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE
for name, module in root_module.named_modules():
if module.__class__.__name__ in target_replace_modules:
for child_name, child_module in module.named_modules():
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_')
name_to_module[lora_name] = child_module
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
if module_name not in name_to_module:
print(f"no module found for LoRA weight: {key}")
continue
module = name_to_module[module_name]
# print(f"apply {key} to {module}")
down_weight = lora_sd[key]
up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D
weight = module.weight
if len(weight.size()) == 2:
# linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
# conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight)
def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {}
alpha = None
dim = None
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...")
for key in lora_sd.keys():
if 'alpha' in key:
if key in merged_sd:
assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
else:
if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
return merged_sd, dim, alpha
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
if args.sd_model is not None:
print(f"loading SD model: {args.sd_model}")
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.sd_model)
merge_to_sd_model(text_encoder, unet, args.models, args.ratios, merge_dtype)
print(f"saving SD model to: {args.save_to}")
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae)
else:
state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--v2", action='store_true',
help='load Stable Diffusion v2.x model / Stable Diffusion 2.xのモデルを読み込む')
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
args = parser.parse_args()
merge(args)

View File

@@ -5,148 +5,169 @@
import argparse import argparse
import os import os
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm from tqdm import tqdm
from library import train_util, model_util
def load_state_dict(file_name, dtype): def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors': if model_util.is_safetensors(file_name):
sd = load_file(file_name) sd = load_file(file_name)
with safe_open(file_name, framework="pt") as f:
metadata = f.metadata()
else: else:
sd = torch.load(file_name, map_location='cpu') sd = torch.load(file_name, map_location='cpu')
metadata = None
for key in list(sd.keys()): for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor: if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype) sd[key] = sd[key].to(dtype)
return sd
return sd, metadata
def save_to_file(file_name, model, state_dict, dtype): def save_to_file(file_name, model, state_dict, dtype, metadata):
if dtype is not None: if dtype is not None:
for key in list(state_dict.keys()): for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor: if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype) state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors': if model_util.is_safetensors(file_name):
save_file(model, file_name) save_file(model, file_name, metadata)
else: else:
torch.save(model, file_name) torch.save(model, file_name)
def resize_lora_model(lora_sd, new_rank, save_dtype, device):
network_alpha = None
network_dim = None
def resize_lora_model(model, new_rank, merge_dtype, save_dtype): CLAMP_QUANTILE = 0.99
print("Loading Model...")
lora_sd = load_state_dict(model, merge_dtype)
network_alpha = None # Extract loaded lora dim and alpha
network_dim = None for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
CLAMP_QUANTILE = 0.99 scale = network_alpha/network_dim
new_alpha = float(scale*new_rank) # calculate new alpha from scale
# Extract loaded lora dim and alpha print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
for key, value in lora_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim
scale = network_alpha/network_dim lora_down_weight = None
new_alpha = float(scale*new_rank) # calculate new alpha from scale lora_up_weight = None
print(f"dimension: {network_dim}, alpha: {network_alpha}, new alpha: {new_alpha}") o_lora_sd = lora_sd.copy()
block_down_name = None
block_up_name = None
lora_down_weight = None print("resizing lora...")
lora_up_weight = None with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
o_lora_sd = lora_sd.copy() weights_loaded = (lora_down_weight is not None and lora_up_weight is not None)
block_down_name = None
block_up_name = None
print("resizing lora...") if (block_down_name == block_up_name) and weights_loaded:
with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
if 'lora_down' in key:
block_down_name = key.split(".")[0]
lora_down_weight = value
if 'lora_up' in key:
block_up_name = key.split(".")[0]
lora_up_weight = value
weights_loaded = (lora_down_weight is not None and lora_up_weight is not None) conv2d = (len(lora_down_weight.size()) == 4)
if (block_down_name == block_up_name) and weights_loaded: if conv2d:
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
conv2d = (len(lora_down_weight.size()) == 4) if device:
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
if conv2d: full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
lora_down_weight = lora_down_weight.squeeze()
lora_up_weight = lora_up_weight.squeeze()
if args.device: U, S, Vh = torch.linalg.svd(full_weight_matrix)
org_device = lora_up_weight.device
lora_up_weight = lora_up_weight.to(args.device)
lora_down_weight = lora_down_weight.to(args.device)
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight) U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
U, S, Vh = torch.linalg.svd(full_weight_matrix) Vh = Vh[:new_rank, :]
U = U[:, :new_rank] dist = torch.cat([U.flatten(), Vh.flatten()])
S = S[:new_rank] hi_val = torch.quantile(dist, CLAMP_QUANTILE)
U = U @ torch.diag(S) low_val = -hi_val
Vh = Vh[:new_rank, :] U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
dist = torch.cat([U.flatten(), Vh.flatten()]) if conv2d:
hi_val = torch.quantile(dist, CLAMP_QUANTILE) U = U.unsqueeze(2).unsqueeze(3)
low_val = -hi_val Vh = Vh.unsqueeze(2).unsqueeze(3)
U = U.clamp(low_val, hi_val) if args.device:
Vh = Vh.clamp(low_val, hi_val) U = U.to(org_device)
Vh = Vh.to(org_device)
if conv2d: o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
U = U.unsqueeze(2).unsqueeze(3) o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
Vh = Vh.unsqueeze(2).unsqueeze(3) o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
if args.device: block_down_name = None
U = U.to(org_device) block_up_name = None
Vh = Vh.to(org_device) lora_down_weight = None
lora_up_weight = None
weights_loaded = False
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous() print("resizing complete")
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous() return o_lora_sd, network_dim, new_alpha
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
block_down_name = None
block_up_name = None
lora_down_weight = None
lora_up_weight = None
weights_loaded = False
print("resizing complete")
return o_lora_sd
def resize(args): def resize(args):
def str_to_dtype(p): def str_to_dtype(p):
if p == 'float': if p == 'float':
return torch.float return torch.float
if p == 'fp16': if p == 'fp16':
return torch.float16 return torch.float16
if p == 'bf16': if p == 'bf16':
return torch.bfloat16 return torch.bfloat16
return None return None
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32 merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
save_dtype = str_to_dtype(args.save_precision) save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None: if save_dtype is None:
save_dtype = merge_dtype save_dtype = merge_dtype
state_dict = resize_lora_model(args.model, args.new_rank, merge_dtype, save_dtype) print("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print(f"saving model to: {args.save_to}") print("resizing rank...")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device)
# update metadata
if metadata is None:
metadata = {}
comment = metadata.get("ss_training_comment", "")
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
metadata["ss_network_dim"] = str(args.new_rank)
metadata["ss_network_alpha"] = str(new_alpha)
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)
if __name__ == '__main__': if __name__ == '__main__':

164
networks/svd_merge_lora.py Normal file
View File

@@ -0,0 +1,164 @@
import math
import argparse
import os
import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm
import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99
def load_state_dict(file_name, dtype):
if os.path.splitext(file_name)[1] == '.safetensors':
sd = load_file(file_name)
else:
sd = torch.load(file_name, map_location='cpu')
for key in list(sd.keys()):
if type(sd[key]) == torch.Tensor:
sd[key] = sd[key].to(dtype)
return sd
def save_to_file(file_name, model, state_dict, dtype):
if dtype is not None:
for key in list(state_dict.keys()):
if type(state_dict[key]) == torch.Tensor:
state_dict[key] = state_dict[key].to(dtype)
if os.path.splitext(file_name)[1] == '.safetensors':
save_file(model, file_name)
else:
torch.save(model, file_name)
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
merged_sd = {}
for model, ratio in zip(models, ratios):
print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype)
# merge
print(f"merging...")
for key in tqdm(list(lora_sd.keys())):
if 'lora_down' not in key:
continue
lora_module_name = key[:key.rfind(".lora_down")]
down_weight = lora_sd[key]
network_dim = down_weight.size()[0]
up_weight = lora_sd[lora_module_name + '.lora_up.weight']
alpha = lora_sd.get(lora_module_name + '.alpha', network_dim)
in_dim = down_weight.size()[1]
out_dim = up_weight.size()[0]
conv2d = len(down_weight.size()) == 4
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
# make original weight if not exist
if lora_module_name not in merged_sd:
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
if device:
weight = weight.to(device)
else:
weight = merged_sd[lora_module_name]
# merge to weight
if device:
up_weight = up_weight.to(device)
down_weight = down_weight.to(device)
# W <- W + U * D
scale = (alpha / network_dim)
if not conv2d: # linear
weight = weight + ratio * (up_weight @ down_weight) * scale
else:
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
).unsqueeze(2).unsqueeze(3) * scale
merged_sd[lora_module_name] = weight
# extract from merged weights
print("extract new lora...")
merged_lora_sd = {}
with torch.no_grad():
for lora_module_name, mat in tqdm(list(merged_sd.items())):
conv2d = (len(mat.size()) == 4)
if conv2d:
mat = mat.squeeze()
U, S, Vh = torch.linalg.svd(mat)
U = U[:, :new_rank]
S = S[:new_rank]
U = U @ torch.diag(S)
Vh = Vh[:new_rank, :]
dist = torch.cat([U.flatten(), Vh.flatten()])
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
low_val = -hi_val
U = U.clamp(low_val, hi_val)
Vh = Vh.clamp(low_val, hi_val)
up_weight = U
down_weight = Vh
if conv2d:
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
return merged_lora_sd
def merge(args):
assert len(args.models) == len(args.ratios), f"number of models must be equal to number of ratios / モデルの数と重みの数は合わせてください"
def str_to_dtype(p):
if p == 'float':
return torch.float
if p == 'fp16':
return torch.float16
if p == 'bf16':
return torch.bfloat16
return None
merge_dtype = str_to_dtype(args.precision)
save_dtype = str_to_dtype(args.save_precision)
if save_dtype is None:
save_dtype = merge_dtype
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--models", type=str, nargs='*',
help="LoRA models to merge: ckpt or safetensors file / マージするLoRAモデル、ckptまたはsafetensors")
parser.add_argument("--ratios", type=float, nargs='*',
help="ratios for each model / それぞれのLoRAモデルの比率")
parser.add_argument("--new_rank", type=int, default=4,
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args()
merge(args)

View File

@@ -1,23 +1,24 @@
accelerate==0.15.0 accelerate==0.15.0
transformers==4.26.0 transformers==4.26.0
ftfy ftfy==6.1.1
albumentations albumentations==1.3.0
opencv-python opencv-python==4.7.0.68
einops einops==0.6.0
diffusers[torch]==0.10.2 diffusers[torch]==0.10.2
pytorch_lightning pytorch-lightning==1.9.0
bitsandbytes==0.35.0 bitsandbytes==0.35.0
tensorboard tensorboard==2.10.1
safetensors==0.2.6 safetensors==0.2.6
gradio gradio==3.16.2
altair altair==4.2.2
easygui easygui==0.98.3
# for BLIP captioning # for BLIP captioning
requests requests==2.28.2
timm==0.4.12 timm==0.6.12
fairscale==0.4.4 fairscale==0.4.13
# for WD14 captioning # for WD14 captioning
tensorflow<2.11 # tensorflow<2.11
huggingface-hub tensorflow==2.10.1
huggingface-hub==0.12.0
# for kohya_ss library # for kohya_ss library
. .

View File

@@ -0,0 +1,122 @@
import glob
import os
import cv2
import argparse
import shutil
import math
from PIL import Image
import numpy as np
def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divisible_by=2, interpolation=None, save_as_png=False, copy_associated_files=False):
# Split the max_resolution string by "," and strip any whitespaces
max_resolutions = [res.strip() for res in max_resolution.split(',')]
# # Calculate max_pixels from max_resolution string
# max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Create destination folder if it does not exist
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
cv2_interpolation = cv2.INTER_LANCZOS4
elif interpolation == 'cubic':
cv2_interpolation = cv2.INTER_CUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
# Check if the image is png, jpg or webp etc...
if not filename.endswith(img_exts):
# Copy the file to the destination folder if not png, jpg or webp etc (.txt or .caption or etc.)
shutil.copy(os.path.join(src_img_folder, filename), os.path.join(dst_img_folder, filename))
continue
# Load image
# img = cv2.imread(os.path.join(src_img_folder, filename))
image = Image.open(os.path.join(src_img_folder, filename))
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
base, _ = os.path.splitext(filename)
for max_resolution in max_resolutions:
# Calculate max_pixels from max_resolution string
max_pixels = int(max_resolution.split("x")[0]) * int(max_resolution.split("x")[1])
# Calculate current number of pixels
current_pixels = img.shape[0] * img.shape[1]
# Check if the image needs resizing
if current_pixels > max_pixels:
# Calculate scaling factor
scale_factor = max_pixels / current_pixels
# Calculate new dimensions
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
new_height, new_width = img.shape[0:2]
# Calculate the new height and width that are divisible by divisible_by (with/without resizing)
new_height = new_height if new_height % divisible_by == 0 else new_height - new_height % divisible_by
new_width = new_width if new_width % divisible_by == 0 else new_width - new_width % divisible_by
# Center crop the image to the calculated dimensions
y = int((img.shape[0] - new_height) / 2)
x = int((img.shape[1] - new_width) / 2)
img = img[y:y + new_height, x:x + new_width]
# Split filename into base and extension
new_filename = base + '+' + max_resolution + ('.png' if save_as_png else '.jpg')
# Save resized image in dst_img_folder
# cv2.imwrite(os.path.join(dst_img_folder, new_filename), img, [cv2.IMWRITE_JPEG_QUALITY, 100])
image = Image.fromarray(img)
image.save(os.path.join(dst_img_folder, new_filename), quality=100)
proc = "Resized" if current_pixels > max_pixels else "Saved"
print(f"{proc} image: {filename} with size {img.shape[0]}x{img.shape[1]} as {new_filename}")
# If other files with same basename, copy them with resolution suffix
if copy_associated_files:
asoc_files = glob.glob(os.path.join(src_img_folder, base + ".*"))
for asoc_file in asoc_files:
ext = os.path.splitext(asoc_file)[1]
if ext in img_exts:
continue
for max_resolution in max_resolutions:
new_asoc_file = base + '+' + max_resolution + ext
print(f"Copy {asoc_file} as {new_asoc_file}")
shutil.copy(os.path.join(src_img_folder, asoc_file), os.path.join(dst_img_folder, new_asoc_file))
def main():
parser = argparse.ArgumentParser(
description='Resize images in a folder to a specified max resolution(s) / 指定されたフォルダ内の画像を指定した最大画像サイズ(面積)以下にアスペクト比を維持したままリサイズします')
parser.add_argument('src_img_folder', type=str, help='Source folder containing the images / 元画像のフォルダ')
parser.add_argument('dst_img_folder', type=str, help='Destination folder to save the resized images / リサイズ後の画像を保存するフォルダ')
parser.add_argument('--max_resolution', type=str,
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')
args = parser.parse_args()
resize_images(args.src_img_folder, args.dst_img_folder, args.max_resolution,
args.divisible_by, args.interpolation, args.save_as_png, args.copy_associated_files)
if __name__ == '__main__':
main()

View File

@@ -38,8 +38,13 @@ def train(args):
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
args.bucket_reso_steps, args.bucket_no_upscale, args.bucket_reso_steps, args.bucket_no_upscale,
args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset) args.prior_loss_weight, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
if args.no_token_padding: if args.no_token_padding:
train_dataset.disable_token_padding() train_dataset.disable_token_padding()
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets() train_dataset.make_buckets()
if args.debug_dataset: if args.debug_dataset:
@@ -203,6 +208,7 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
# 指定したステップ数までText Encoderを学習するepoch最初の状態 # 指定したステップ数までText Encoderを学習するepoch最初の状態
unet.train() unet.train()
@@ -327,7 +333,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, False) train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
train_util.add_sd_saving_arguments(parser) train_util.add_sd_saving_arguments(parser)

View File

@@ -133,6 +133,10 @@ def train(args):
args.bucket_reso_steps, args.bucket_no_upscale, args.bucket_reso_steps, args.bucket_no_upscale,
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
args.dataset_repeats, args.debug_dataset) args.dataset_repeats, args.debug_dataset)
# 学習データのdropout率を設定する
train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)
train_dataset.make_buckets() train_dataset.make_buckets()
if args.debug_dataset: if args.debug_dataset:
@@ -387,6 +391,8 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
metadata["ss_epoch"] = str(epoch+1) metadata["ss_epoch"] = str(epoch+1)
network.on_epoch_start(text_encoder, unet) network.on_epoch_start(text_encoder, unet)
@@ -521,7 +527,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True) train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")

View File

@@ -55,7 +55,7 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py
--network_module=networks.lora --network_module=networks.lora
``` ```
--output_dirオプションで指定したディレクトリに、LoRAのモデルが保存されます。 --output_dirオプションで指定したフォルダに、LoRAのモデルが保存されます。
その他、以下のオプションが指定できます。 その他、以下のオプションが指定できます。
@@ -178,6 +178,38 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA
- --save_precision - --save_precision
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。 - LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
## 画像リサイズスクリプト
(のちほどドキュメントを整理しますがとりあえずここに説明を書いておきます。)
Aspect Ratio Bucketingの機能拡張で、小さな画像については拡大しないでそのまま教師データとすることが可能になりました。元の教師画像を縮小した画像を、教師データに加えると精度が向上したという報告とともに前処理用のスクリプトをいただきましたので整備して追加しました。bmaltais氏に感謝します。
### スクリプトの実行方法
以下のように指定してください。元の画像そのまま、およびリサイズ後の画像が変換先フォルダに保存されます。リサイズ後の画像には、ファイル名に ``+512x512`` のようにリサイズ先の解像度が付け加えられます(画像サイズとは異なります)。リサイズ先の解像度より小さい画像は拡大されることはありません。
```
python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256x256 --save_as_png
--copy_associated_files 元画像フォルダ 変換先フォルダ
```
元画像フォルダ内の画像ファイルが、指定した解像度(複数指定可)と同じ面積になるようにリサイズされ、変換先フォルダに保存されます。画像以外のファイルはそのままコピーされます。
``--max_resolution`` オプションにリサイズ先のサイズを例のように指定してください。面積がそのサイズになるようにリサイズします。複数指定すると、それぞれの解像度でリサイズされます。``512x512,384x384,256x256``なら、変換先フォルダの画像は、元サイズとリサイズ後サイズ×3の計4枚になります。
``--save_as_png`` オプションを指定するとpng形式で保存します。省略するとjpeg形式quality=100で保存されます。
``--copy_associated_files`` オプションを指定すると、拡張子を除き画像と同じファイル名(たとえばキャプションなど)のファイルが、リサイズ後の画像のファイル名と同じ名前でコピーされます。
### その他のオプション
- divisible_by
- リサイズ後の画像のサイズ(縦、横のそれぞれ)がこの値で割り切れるように、画像中心を切り出します。
- interpolation
- 縮小時の補完方法を指定します。``area, cubic, lanczos4``から選択可能で、デフォルトは``area``です。
## 追加情報 ## 追加情報
### cloneofsimo氏のリポジトリとの違い ### cloneofsimo氏のリポジトリとの違い

View File

@@ -98,12 +98,12 @@ def train(args):
# Convert the init_word to token_id # Convert the init_word to token_id
if args.init_word is not None: if args.init_word is not None:
init_token_id = tokenizer.encode(args.init_word, add_special_tokens=False) init_token_ids = tokenizer.encode(args.init_word, add_special_tokens=False)
assert len( if len(init_token_ids) > 1 and len(init_token_ids) != args.num_vectors_per_token:
init_token_id) == 1, f"init word {args.init_word} is not converted to single token / 初期化単語が二つ以上のトークンに変換されます。別の単語を使ってください" print(
init_token_id = init_token_id[0] f"token length for init words is not same to num_vectors_per_token, init words is repeated or truncated / 初期化単語のトークン長がnum_vectors_per_tokenと合わないため、繰り返しまたは切り捨てが発生します: length {len(init_token_ids)}")
else: else:
init_token_id = None init_token_ids = None
# add new word to tokenizer, count is num_vectors_per_token # add new word to tokenizer, count is num_vectors_per_token
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)] token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
@@ -120,9 +120,9 @@ def train(args):
# Initialise the newly added placeholder token with the embeddings of the initializer token # Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder.get_input_embeddings().weight.data token_embeds = text_encoder.get_input_embeddings().weight.data
if init_token_id is not None: if init_token_ids is not None:
for token_id in token_ids: for i, token_id in enumerate(token_ids):
token_embeds[token_id] = token_embeds[init_token_id] token_embeds[token_id] = token_embeds[init_token_ids[i % len(init_token_ids)]]
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min()) # print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
# load weights # load weights
@@ -235,7 +235,7 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler) text_encoder, optimizer, train_dataloader, lr_scheduler)
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0] index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
print(len(index_no_updates), torch.sum(index_no_updates)) # print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone() orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
# Freeze all parameters except for the token embeddings in text encoder # Freeze all parameters except for the token embeddings in text encoder
@@ -296,6 +296,7 @@ def train(args):
for epoch in range(num_train_epochs): for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}") print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset.set_current_epoch(epoch + 1)
text_encoder.train() text_encoder.train()
@@ -383,8 +384,8 @@ def train(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone() updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
d = updated_embs - bef_epo_embs # d = updated_embs - bef_epo_embs
print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min()) # print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
if args.save_every_n_epochs is not None: if args.save_every_n_epochs is not None:
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
@@ -478,7 +479,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
train_util.add_sd_models_arguments(parser) train_util.add_sd_models_arguments(parser)
train_util.add_dataset_arguments(parser, True, True) train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
@@ -491,7 +492,7 @@ if __name__ == '__main__':
parser.add_argument("--token_string", type=str, default=None, parser.add_argument("--token_string", type=str, default=None,
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること") help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
parser.add_argument("--init_word", type=str, default=None, parser.add_argument("--init_word", type=str, default=None,
help="word to initialize vector / ベクトルを初期化に使用する単語、tokenizerで一語になること") help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
parser.add_argument("--use_object_template", action='store_true', parser.add_argument("--use_object_template", action='store_true',
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する") help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
parser.add_argument("--use_style_template", action='store_true', parser.add_argument("--use_style_template", action='store_true',