Merge branch 'main' into textual_inversion

This commit is contained in:
Kohya S
2023-01-26 17:50:20 +09:00
14 changed files with 360 additions and 100 deletions

View File

@@ -1,7 +1,7 @@
## リポジトリについて ## リポジトリについて
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。 Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
[README in English](./README.md) [README in English](./README.md) ←更新情報はこちらにあります
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています英語ですのであわせてご覧ください。bmaltais氏に感謝します。 GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています英語ですのであわせてご覧ください。bmaltais氏に感謝します。
@@ -16,9 +16,10 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください将来的にはすべてこちらへ移すかもしれません 当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください将来的にはすべてこちらへ移すかもしれません
* note.com [環境整備とDreamBooth学習スクリプトについて](https://note.com/kohya_ss/n/nba4eceaa4594) * [DreamBooth学習について](./train_db_README-ja.md)
* [fine-tuningのガイド](./fine_tune_README_ja.md): * [fine-tuningのガイド](./fine_tune_README_ja.md):
BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます
* [LoRAの学習について](./train_network_README-ja.md)
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e) * note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad) * note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
@@ -44,12 +45,11 @@ PowerShellを使う場合、venvを使えるようにするためには以下の
通常の管理者ではないPowerShellを開き以下を順に実行します。 通常の管理者ではないPowerShellを開き以下を順に実行します。
```powershell ```powershell
git clone https://github.com/kohya-ss/sd-scripts.git git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts cd sd-scripts
python -m venv --system-site-packages venv python -m venv venv
.\venv\Scripts\activate .\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
@@ -70,7 +70,7 @@ accelerate config
git clone https://github.com/kohya-ss/sd-scripts.git git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts cd sd-scripts
python -m venv --system-site-packages venv python -m venv venv
.\venv\Scripts\activate .\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
@@ -84,6 +84,8 @@ copy /y .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cud
accelerate config accelerate config
``` ```
(注:``python -m venv venv`` のほうが ``python -m venv --system-site-packages venv`` より安全そうなため書き換えました。globalなpythonにパッケージがインストールしてあると、後者だといろいろと問題が起きます。
accelerate configの質問には以下のように答えてください。bf16で学習する場合、最後の質問にはbf16と答えてください。 accelerate configの質問には以下のように答えてください。bf16で学習する場合、最後の質問にはbf16と答えてください。
※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます……。数字キーの0、1、2……で選択できますので、そちらを使ってください。 ※0.15.0から日本語環境では選択のためにカーソルキーを押すと落ちます……。数字キーの0、1、2……で選択できますので、そちらを使ってください。
@@ -99,7 +101,11 @@ accelerate configの質問には以下のように答えてください。bf1
``` ```
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問 ※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問
``What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。 ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``に「0」と答えてください。id `0`のGPUが使われます。
### PyTorchとxformersのバージョンについて
他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。
## アップグレード ## アップグレード

View File

@@ -2,16 +2,30 @@ This repository contains training, generation and utility scripts for Stable Dif
## Updates ## Updates
- 15 Jan. 2023, 2023/1/15 __Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
- Added ``--max_train_epochs`` and ``--max_data_loader_n_workers`` option for each training script.
- If you specify the number of training epochs with ``--max_train_epochs``, the number of steps is calculated from the number of epochs automatically.
- You can set the number of workers for DataLoader with ``--max_data_loader_n_workers``, default is 8. The lower number may reduce the main memory usage and the time between epochs, but may cause slower dataloading (training).
- ``--max_train_epochs`` と ``--max_data_loader_n_workers`` のオプションが学習スクリプトに追加されました。
- ``--max_train_epochs`` で学習したいエポック数を指定すると、必要なステップ数が自動的に計算され設定されます。
- ``--max_data_loader_n_workers`` で DataLoader の worker 数が指定できますデフォルトは8。値を小さくするとメインメモリの使用量が減り、エポック間の待ち時間も短くなるようです。ただしデータ読み込み学習時間は長くなる可能性があります。
Please read [release version 0.3.0](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.3.0) for recent updates. Note: The LoRA models for SD 2.x is not supported too in Web UI.
最近の更新情報は [release version 0.3.0](https://github.com/kohya-ss/sd-scripts/releases/tag/v0.3.0) をご覧ください。
- 24 Jan. 2023, 2023/1/24
- Change the default save format to ``.safetensors`` for ``train_network.py``.
- Add ``--save_n_epoch_ratio`` option to specify how often to save. Thanks to forestsource!
- For example, if 5 is specified, 5 (or 6) files will be saved in training.
- Add feature to pre-caclulate hash to reduce loading time in the extension. Thanks to space-nuko!
- Add bucketing matadata. Thanks to space-nuko!
- Fix an error with bf16 model in ``gen_img_diffusers.py``.
- ``train_network.py`` のモデル保存形式のデフォルトを ``.safetensors`` に変更しました。
- モデルを保存する頻度を指定する ``--save_n_epoch_ratio`` オプションが追加されました。forestsource氏に感謝します。
- たとえば 5 を指定すると、学習終了までに合計で5個または6個のファイルが保存されます。
- 拡張でモデル読み込み時間を短縮するためのハッシュ事前計算の機能を追加しました。space-nuko氏に感謝します。
- メタデータにbucket情報が追加されました。space-nuko氏に感謝します。
- ``gen_img_diffusers.py`` でbf16形式のモデルを読み込んだときのエラーを修正しました。
Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。
SD2.x用のLoRAモデルはサポートされないようです。
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
## ##
@@ -65,7 +79,7 @@ Open a regular Powershell terminal and type the following inside:
git clone https://github.com/kohya-ss/sd-scripts.git git clone https://github.com/kohya-ss/sd-scripts.git
cd sd-scripts cd sd-scripts
python -m venv --system-site-packages venv python -m venv venv
.\venv\Scripts\activate .\venv\Scripts\activate
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
@@ -77,9 +91,10 @@ cp .\bitsandbytes_windows\cextension.py .\venv\Lib\site-packages\bitsandbytes\ce
cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py cp .\bitsandbytes_windows\main.py .\venv\Lib\site-packages\bitsandbytes\cuda_setup\main.py
accelerate config accelerate config
``` ```
update: ``python -m venv venv`` is seemed to be safer than ``python -m venv --system-site-packages venv`` (some user have packages in global python).
Answers to accelerate config: Answers to accelerate config:
```txt ```txt
@@ -92,11 +107,16 @@ Answers to accelerate config:
- fp16 - fp16
``` ```
note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occured in training. In this case, answer `0` for the 6th question: note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is occurred in training. In this case, answer `0` for the 6th question:
``What GPU(s) (by id) should be used for training on this machine as a comma-seperated list? [all]:`` ``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``
(Single GPU with id `0` will be used.) (Single GPU with id `0` will be used.)
### about PyTorch and xformers
Other versions of PyTorch and xformers seem to have problems with training.
If there is no other reason, please install the specified version.
## Upgrade ## Upgrade
When a new release comes out you can upgrade your repo with the following command: When a new release comes out you can upgrade your repo with the following command:

View File

@@ -200,6 +200,8 @@ def train(args):
# epoch数を計算する # epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する # 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

View File

@@ -324,7 +324,7 @@ __※引数を都度書き換えて、別のメタデータファイルに書き
## 学習の実行 ## 学習の実行
たとえば以下のように実行します。以下は省メモリ化のための設定です。 たとえば以下のように実行します。以下は省メモリ化のための設定です。
``` ```
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
--pretrained_model_name_or_path=model.ckpt --pretrained_model_name_or_path=model.ckpt
--in_json meta_lat.json --in_json meta_lat.json
--train_data_dir=train_data --train_data_dir=train_data
@@ -336,7 +336,7 @@ accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
--save_every_n_epochs=4 --save_every_n_epochs=4
``` ```
accelerateのnum_cpu_threads_per_processにはCPUのコア数を指定するとよいようです。 accelerateのnum_cpu_threads_per_processには通常は1を指定するとよいようです。
pretrained_model_name_or_pathに学習対象のモデルを指定しますStable DiffusionのcheckpointかDiffusersのモデル。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています拡張子で自動判定 pretrained_model_name_or_pathに学習対象のモデルを指定しますStable DiffusionのcheckpointかDiffusersのモデル。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています拡張子で自動判定

View File

@@ -1981,7 +1981,6 @@ def main(args):
imported_module = importlib.import_module(network_module) imported_module = importlib.import_module(network_module)
network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i] network_mul = 1.0 if args.network_mul is None or len(args.network_mul) <= i else args.network_mul[i]
network_dim = None if args.network_dim is None or len(args.network_dim) <= i else args.network_dim[i]
net_kwargs = {} net_kwargs = {}
if args.network_args and i < len(args.network_args): if args.network_args and i < len(args.network_args):
@@ -1992,22 +1991,22 @@ def main(args):
key, value = net_arg.split("=") key, value = net_arg.split("=")
net_kwargs[key] = value net_kwargs[key] = value
network = imported_module.create_network(network_mul, network_dim, vae, text_encoder, unet, **net_kwargs)
if network is None:
return
if args.network_weights and i < len(args.network_weights): if args.network_weights and i < len(args.network_weights):
network_weight = args.network_weights[i] network_weight = args.network_weights[i]
print("load network weights from:", network_weight) print("load network weights from:", network_weight)
if os.path.splitext(network_weight)[1] == '.safetensors': if model_util.is_safetensors(network_weight):
from safetensors.torch import safe_open from safetensors.torch import safe_open
with safe_open(network_weight, framework="pt") as f: with safe_open(network_weight, framework="pt") as f:
metadata = f.metadata() metadata = f.metadata()
if metadata is not None: if metadata is not None:
print(f"metadata for: {network_weight}: {metadata}") print(f"metadata for: {network_weight}: {metadata}")
network.load_weights(network_weight) network = imported_module.create_network_from_weights(network_mul, network_weight, vae, text_encoder, unet, **net_kwargs)
else:
raise ValueError("No weight. Weight is required.")
if network is None:
return
network.apply_to(text_encoder, unet) network.apply_to(text_encoder, unet)
@@ -2518,16 +2517,14 @@ if __name__ == '__main__':
parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する') parser.add_argument("--bf16", action='store_true', help='use bfloat16 / bfloat16を指定し省メモリ化する')
parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する') parser.add_argument("--xformers", action='store_true', help='use xformers / xformersを使用し高速化する')
parser.add_argument("--diffusers_xformers", action='store_true', parser.add_argument("--diffusers_xformers", action='store_true',
help='use xformers by diffusers (Hypernetworks doen\'t work) / Diffusersでxformersを使用するHypernetwork利用不可') help='use xformers by diffusers (Hypernetworks doesn\'t work) / Diffusersでxformersを使用するHypernetwork利用不可')
parser.add_argument("--opt_channels_last", action='store_true', parser.add_argument("--opt_channels_last", action='store_true',
help='set channels last option to model / モデルにchannles lastを指定し最適化する') help='set channels last option to model / モデルにchannels lastを指定し最適化する')
parser.add_argument("--network_module", type=str, default=None, nargs='*', parser.add_argument("--network_module", type=str, default=None, nargs='*',
help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名') help='Hypernetwork module to use / Hypernetworkを使う時そのモジュール名')
parser.add_argument("--network_weights", type=str, default=None, nargs='*', parser.add_argument("--network_weights", type=str, default=None, nargs='*',
help='Hypernetwork weights to load / Hypernetworkの重み') help='Hypernetwork weights to load / Hypernetworkの重み')
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率') parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
parser.add_argument("--network_dim", type=int, default=None, nargs='*',
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_args", type=str, default=None, nargs='*', parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数') help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う') parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')

View File

@@ -11,6 +11,8 @@ import glob
import math import math
import os import os
import random import random
import hashlib
from io import BytesIO
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -24,6 +26,7 @@ from PIL import Image
import cv2 import cv2
from einops import rearrange from einops import rearrange
from torch import einsum from torch import einsum
import safetensors.torch
import library.model_util as model_util import library.model_util as model_util
@@ -79,6 +82,12 @@ class BaseDataset(torch.utils.data.Dataset):
self.debug_dataset = debug_dataset self.debug_dataset = debug_dataset
self.random_crop = random_crop self.random_crop = random_crop
self.token_padding_disabled = False self.token_padding_disabled = False
self.dataset_dirs_info = {}
self.reg_dataset_dirs_info = {}
self.enable_bucket = False
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_info = None
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2 self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
@@ -227,11 +236,17 @@ class BaseDataset(torch.utils.data.Dataset):
self.buckets[bucket_index].append(image_info.image_key) self.buckets[bucket_index].append(image_info.image_key)
if self.enable_bucket: if self.enable_bucket:
self.bucket_info = {"buckets": {}}
print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む") print("number of images (including repeats) / 各bucketの画像枚数繰り返し回数を含む")
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)): for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}") print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
img_ar_errors = np.array(img_ar_errors) img_ar_errors = np.array(img_ar_errors)
print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}") mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
print(f"mean ar error (without repeats): {mean_img_ar_error}")
# 参照用indexを作る # 参照用indexを作る
self.buckets_indices: list(BucketBatchIndex) = [] self.buckets_indices: list(BucketBatchIndex) = []
@@ -479,6 +494,8 @@ class DreamBoothDataset(BaseDataset):
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso) (self.width, self.height), min_bucket_reso, max_bucket_reso)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
else: else:
self.bucket_resos = [(self.width, self.height)] self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height] self.bucket_aspect_ratios = [self.width / self.height]
@@ -539,6 +556,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, False, img_path) info = ImageInfo(img_path, n_repeats, caption, False, img_path)
self.register_image(info) self.register_image(info)
self.dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_train_images} train images with repeating.") print(f"{num_train_images} train images with repeating.")
self.num_train_images = num_train_images self.num_train_images = num_train_images
@@ -555,6 +573,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption in zip(img_paths, captions): for img_path, caption in zip(img_paths, captions):
info = ImageInfo(img_path, n_repeats, caption, True, img_path) info = ImageInfo(img_path, n_repeats, caption, True, img_path)
reg_infos.append(info) reg_infos.append(info)
self.reg_dataset_dirs_info[os.path.basename(dir)] = {"n_repeats": n_repeats, "img_count": len(img_paths)}
print(f"{num_reg_images} reg images.") print(f"{num_reg_images} reg images.")
if num_train_images < num_reg_images: if num_train_images < num_reg_images:
@@ -627,6 +646,8 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = len(metadata) * dataset_repeats self.num_train_images = len(metadata) * dataset_repeats
self.num_reg_images = 0 self.num_reg_images = 0
self.dataset_dirs_info[os.path.basename(self.train_data_dir)] = {"n_repeats": dataset_repeats, "img_count": len(metadata)}
# check existence of all npz files # check existence of all npz files
if not self.color_aug: if not self.color_aug:
npz_any = False npz_any = False
@@ -669,6 +690,8 @@ class FineTuningDataset(BaseDataset):
assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions( self.bucket_resos, self.bucket_aspect_ratios = model_util.make_bucket_resolutions(
(self.width, self.height), min_bucket_reso, max_bucket_reso) (self.width, self.height), min_bucket_reso, max_bucket_reso)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
else: else:
self.bucket_resos = [(self.width, self.height)] self.bucket_resos = [(self.width, self.height)]
self.bucket_aspect_ratios = [self.width / self.height] self.bucket_aspect_ratios = [self.width / self.height]
@@ -681,6 +704,9 @@ class FineTuningDataset(BaseDataset):
self.bucket_resos.sort() self.bucket_resos.sort()
self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos] self.bucket_aspect_ratios = [w / h for w, h in self.bucket_resos]
self.min_bucket_reso = min([min(reso) for reso in resos])
self.max_bucket_reso = max([max(reso) for reso in resos])
def image_key_to_npz_file(self, image_key): def image_key_to_npz_file(self, image_key):
base_name = os.path.splitext(image_key)[0] base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + '.npz' npz_file_norm = base_name + '.npz'
@@ -767,9 +793,9 @@ def default(val, d):
def model_hash(filename): def model_hash(filename):
"""Old model hash used by stable-diffusion-webui"""
try: try:
with open(filename, "rb") as file: with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256() m = hashlib.sha256()
file.seek(0x100000) file.seek(0x100000)
@@ -779,6 +805,61 @@ def model_hash(filename):
return 'NOFILE' return 'NOFILE'
def calculate_sha256(filename):
"""New model hash used by stable-diffusion-webui"""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
def precalculate_safetensors_hashes(tensors, metadata):
"""Precalculate the model hashes needed by sd-webui-additional-networks to
save time on indexing the model later."""
# Because writing user metadata to the file can change the result of
# sd_models.model_hash(), only retain the training metadata for purposes of
# calculating the hash, as they are meant to be immutable
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
bytes = safetensors.torch.save(tensors, metadata)
b = BytesIO(bytes)
model_hash = addnet_hash_safetensors(b)
legacy_hash = addnet_hash_legacy(b)
return model_hash, legacy_hash
def addnet_hash_legacy(b):
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
m = hashlib.sha256()
b.seek(0x100000)
m.update(b.read(0x10000))
return m.hexdigest()[0:8]
def addnet_hash_safetensors(b):
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
hash_sha256 = hashlib.sha256()
blksize = 1024 * 1024
b.seek(0)
header = b.read(8)
n = int.from_bytes(header, "little")
offset = n + 8
b.seek(offset)
for chunk in iter(lambda: b.read(blksize), b""):
hash_sha256.update(chunk)
return hash_sha256.hexdigest()
# flash attention forwards and backwards # flash attention forwards and backwards
# https://arxiv.org/abs/2205.14135 # https://arxiv.org/abs/2205.14135
@@ -1046,7 +1127,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する") choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
parser.add_argument("--save_every_n_epochs", type=int, default=None, parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する") help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存するたとえば5を指定すると最低5個のファイルが保存される")
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する") parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
parser.add_argument("--save_state", action="store_true", parser.add_argument("--save_state", action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する") help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate") parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
@@ -1065,8 +1150,10 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率") parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数") parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
parser.add_argument("--max_train_epochs", type=int, default=None, help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします") parser.add_argument("--max_train_epochs", type=int, default=None,
parser.add_argument("--max_data_loader_n_workers", type=int, default=8, help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります)") help="training epochs (overrides max_train_steps) / 学習エポック数max_train_stepsを上書きします)")
parser.add_argument("--max_data_loader_n_workers", type=int, default=8,
help="max num workers for DataLoader (lower is less main RAM usage, faster epoch start and slower data loading) / DataLoaderの最大プロセス数小さい値ではメインメモリの使用量が減りエポック間の待ち時間が減りますが、データ読み込みは遅くなります")
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument("--gradient_checkpointing", action="store_true", parser.add_argument("--gradient_checkpointing", action="store_true",
help="enable gradient checkpointing / grandient checkpointingを有効にする") help="enable gradient checkpointing / grandient checkpointingを有効にする")
@@ -1316,7 +1403,6 @@ def get_epoch_ckpt_name(args: argparse.Namespace, use_safetensors, epoch):
def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int): def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoch_no: int, num_train_epochs: int):
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
remove_epoch_no = None
if saving: if saving:
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
save_func() save_func()
@@ -1324,7 +1410,7 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc
if args.save_last_n_epochs is not None: if args.save_last_n_epochs is not None:
remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs remove_epoch_no = epoch_no - args.save_every_n_epochs * args.save_last_n_epochs
remove_old_func(remove_epoch_no) remove_old_func(remove_epoch_no)
return saving, remove_epoch_no return saving
def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae): def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, num_train_epochs: int, global_step: int, text_encoder, unet, vae):
@@ -1364,15 +1450,18 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
save_func = save_du save_func = save_du
remove_old_func = remove_du remove_old_func = remove_du
saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs) saving = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
if saving and args.save_state: if saving and args.save_state:
save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no) save_state_on_epoch_end(args, accelerator, model_name, epoch_no)
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no): def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no):
print("saving state.") print("saving state.")
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no))) accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
if remove_epoch_no is not None:
last_n_epochs = args.save_last_n_epochs_state if args.save_last_n_epochs_state else args.save_last_n_epochs
if last_n_epochs is not None:
remove_epoch_no = epoch_no - args.save_every_n_epochs * last_n_epochs
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no)) state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
if os.path.exists(state_dir_old): if os.path.exists(state_dir_old):
print(f"removing old state: {state_dir_old}") print(f"removing old state: {state_dir_old}")

View File

@@ -0,0 +1,32 @@
import argparse
import os
import torch
from safetensors.torch import load_file
def main(file):
print(f"loading: {file}")
if os.path.splitext(file)[1] == '.safetensors':
sd = load_file(file)
else:
sd = torch.load(file, map_location='cpu')
values = []
keys = list(sd.keys())
for key in keys:
if 'lora_up' in key or 'lora_down' in key:
values.append((key, sd[key]))
print(f"number of LoRA modules: {len(values)}")
for key, value in values:
value = value.to(torch.float32)
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("file", type=str, help="model file to check / 重みを確認するモデルファイル")
args = parser.parse_args()
main(args.file)

View File

@@ -44,9 +44,9 @@ def svd(args):
print(f"loading SD model : {args.model_tuned}") print(f"loading SD model : {args.model_tuned}")
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned) text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
# create LoRA network to extract weights # create LoRA network to extract weights: Use dim (rank) as alpha
lora_network_o = lora.create_network(1.0, args.dim, None, text_encoder_o, unet_o) lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
lora_network_t = lora.create_network(1.0, args.dim, None, text_encoder_t, unet_t) lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
assert len(lora_network_o.text_encoder_loras) == len( assert len(lora_network_o.text_encoder_loras) == len(
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース " lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違いますSD1.xベースとSD2.xベース "
@@ -116,6 +116,9 @@ def svd(args):
print(f"LoRA has {len(lora_sd)} weights.") print(f"LoRA has {len(lora_sd)} weights.")
for key in list(lora_sd.keys()): for key in list(lora_sd.keys()):
if "alpha" in key:
continue
lora_name = key.split('.')[0] lora_name = key.split('.')[0]
i = 0 if "lora_up" in key else 1 i = 0 if "lora_up" in key else 1
@@ -124,7 +127,7 @@ def svd(args):
if len(lora_sd[key].size()) == 4: if len(lora_sd[key].size()) == 4:
weights = weights.unsqueeze(2).unsqueeze(3) weights = weights.unsqueeze(2).unsqueeze(3)
assert weights.size() == lora_sd[key].size() assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
lora_sd[key] = weights lora_sd[key] = weights
# load state dict to LoRA and save it # load state dict to LoRA and save it
@@ -135,7 +138,10 @@ def svd(args):
if dir_name and not os.path.exists(dir_name): if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True) os.makedirs(dir_name, exist_ok=True)
lora_network_o.save_weights(args.save_to, save_dtype, {}) # minimum metadata
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
print(f"LoRA weights are saved to: {args.save_to}") print(f"LoRA weights are saved to: {args.save_to}")
@@ -151,8 +157,8 @@ if __name__ == '__main__':
help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors") help="Stable Diffusion tuned model, LoRA is difference of `original to tuned`: ckpt or safetensors file / 派生モデル生成されるLoRAは元→派生の差分になります、ckptまたはsafetensors")
parser.add_argument("--save_to", type=str, default=None, parser.add_argument("--save_to", type=str, default=None,
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors") help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
parser.add_argument("--dim", type=int, default=4, help="dimension of LoRA (default 4) / LoRAの次元数デフォルト4") parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数rankデフォルト4")
parser.add_argument("--device", type=str, default=None, help="device to use, 'cuda' for GPU / 計算を行うデバイス、'cuda'でGPUを使う") parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
args = parser.parse_args() args = parser.parse_args()
svd(args) svd(args)

View File

@@ -7,15 +7,19 @@ import math
import os import os
import torch import torch
from library import train_util
class LoRAModule(torch.nn.Module): class LoRAModule(torch.nn.Module):
""" """
replaces forward method of the original Linear, instead of replacing the original Linear module. replaces forward method of the original Linear, instead of replacing the original Linear module.
""" """
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4): def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
""" if alpha == 0 or None, alpha is rank (no scaling). """
super().__init__() super().__init__()
self.lora_name = lora_name self.lora_name = lora_name
self.lora_dim = lora_dim
if org_module.__class__.__name__ == 'Conv2d': if org_module.__class__.__name__ == 'Conv2d':
in_dim = org_module.in_channels in_dim = org_module.in_channels
@@ -28,6 +32,12 @@ class LoRAModule(torch.nn.Module):
self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False) self.lora_down = torch.nn.Linear(in_dim, lora_dim, bias=False)
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False) self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
if type(alpha) == torch.Tensor:
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
alpha = lora_dim if alpha is None or alpha == 0 else alpha
self.scale = alpha / self.lora_dim
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
# same as microsoft's # same as microsoft's
torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
torch.nn.init.zeros_(self.lora_up.weight) torch.nn.init.zeros_(self.lora_up.weight)
@@ -41,13 +51,37 @@ class LoRAModule(torch.nn.Module):
del self.org_module del self.org_module
def forward(self, x): def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def create_network(multiplier, network_dim, vae, text_encoder, unet, **kwargs): def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
if network_dim is None: if network_dim is None:
network_dim = 4 # default network_dim = 4 # default
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim) network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
return network
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwargs):
if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import load_file, safe_open
weights_sd = load_file(file)
else:
weights_sd = torch.load(file, map_location='cpu')
# get dim (rank)
network_alpha = None
network_dim = None
for key, value in weights_sd.items():
if network_alpha is None and 'alpha' in key:
network_alpha = value
if network_dim is None and 'lora_down' in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is None:
network_alpha = network_dim
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, lora_dim=network_dim, alpha=network_alpha)
network.weights_sd = weights_sd
return network return network
@@ -57,10 +91,11 @@ class LoRANetwork(torch.nn.Module):
LORA_PREFIX_UNET = 'lora_unet' LORA_PREFIX_UNET = 'lora_unet'
LORA_PREFIX_TEXT_ENCODER = 'lora_te' LORA_PREFIX_TEXT_ENCODER = 'lora_te'
def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4) -> None: def __init__(self, text_encoder, unet, multiplier=1.0, lora_dim=4, alpha=1) -> None:
super().__init__() super().__init__()
self.multiplier = multiplier self.multiplier = multiplier
self.lora_dim = lora_dim self.lora_dim = lora_dim
self.alpha = alpha
# create module instances # create module instances
def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]: def create_modules(prefix, root_module: torch.nn.Module, target_replace_modules) -> list[LoRAModule]:
@@ -71,7 +106,7 @@ class LoRANetwork(torch.nn.Module):
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)): if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
lora_name = prefix + '.' + name + '.' + child_name lora_name = prefix + '.' + name + '.' + child_name
lora_name = lora_name.replace('.', '_') lora_name = lora_name.replace('.', '_')
lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim) lora = LoRAModule(lora_name, child_module, self.multiplier, self.lora_dim, self.alpha)
loras.append(lora) loras.append(lora)
return loras return loras
@@ -149,21 +184,21 @@ class LoRANetwork(torch.nn.Module):
return params return params
self.requires_grad_(True) self.requires_grad_(True)
params = [] all_params = []
if self.text_encoder_loras: if self.text_encoder_loras:
param_data = {'params': enumerate_params(self.text_encoder_loras)} param_data = {'params': enumerate_params(self.text_encoder_loras)}
if text_encoder_lr is not None: if text_encoder_lr is not None:
param_data['lr'] = text_encoder_lr param_data['lr'] = text_encoder_lr
params.append(param_data) all_params.append(param_data)
if self.unet_loras: if self.unet_loras:
param_data = {'params': enumerate_params(self.unet_loras)} param_data = {'params': enumerate_params(self.unet_loras)}
if unet_lr is not None: if unet_lr is not None:
param_data['lr'] = unet_lr param_data['lr'] = unet_lr
params.append(param_data) all_params.append(param_data)
return params return all_params
def prepare_grad_etc(self, text_encoder, unet): def prepare_grad_etc(self, text_encoder, unet):
self.requires_grad_(True) self.requires_grad_(True)
@@ -188,6 +223,14 @@ class LoRANetwork(torch.nn.Module):
if os.path.splitext(file)[1] == '.safetensors': if os.path.splitext(file)[1] == '.safetensors':
from safetensors.torch import save_file from safetensors.torch import save_file
# Precalculate model hashes to save time on indexing
if metadata is None:
metadata = {}
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
save_file(state_dict, file, metadata) save_file(state_dict, file, metadata)
else: else:
torch.save(state_dict, file) torch.save(state_dict, file)

View File

@@ -61,6 +61,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
for key in lora_sd.keys(): for key in lora_sd.keys():
if "lora_down" in key: if "lora_down" in key:
up_key = key.replace("lora_down", "lora_up") up_key = key.replace("lora_down", "lora_up")
alpha_key = key[:key.index("lora_down")] + 'alpha'
# find original module for this lora # find original module for this lora
module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight" module_name = '.'.join(key.split('.')[:-2]) # remove trailing ".lora_down.weight"
@@ -73,14 +74,18 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
down_weight = lora_sd[key] down_weight = lora_sd[key]
up_weight = lora_sd[up_key] up_weight = lora_sd[up_key]
dim = down_weight.size()[0]
alpha = lora_sd.get(alpha_key, dim)
scale = alpha / dim
# W <- W + U * D # W <- W + U * D
weight = module.weight weight = module.weight
if len(weight.size()) == 2: if len(weight.size()) == 2:
# linear # linear
weight = weight + ratio * (up_weight @ down_weight) weight = weight + ratio * (up_weight @ down_weight) * scale
else: else:
# conv2d # conv2d
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) * scale
module.weight = torch.nn.Parameter(weight) module.weight = torch.nn.Parameter(weight)
@@ -88,20 +93,35 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
def merge_lora_models(models, ratios, merge_dtype): def merge_lora_models(models, ratios, merge_dtype):
merged_sd = {} merged_sd = {}
alpha = None
dim = None
for model, ratio in zip(models, ratios): for model, ratio in zip(models, ratios):
print(f"loading: {model}") print(f"loading: {model}")
lora_sd = load_state_dict(model, merge_dtype) lora_sd = load_state_dict(model, merge_dtype)
print(f"merging...") print(f"merging...")
for key in lora_sd.keys(): for key in lora_sd.keys():
if key in merged_sd: if 'alpha' in key:
assert merged_sd[key].size() == lora_sd[key].size( if key in merged_sd:
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません" assert merged_sd[key] == lora_sd[key], f"alpha mismatch / alphaが異なる場合、現時点ではマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio else:
alpha = lora_sd[key].detach().numpy()
merged_sd[key] = lora_sd[key]
else: else:
merged_sd[key] = lora_sd[key] * ratio if key in merged_sd:
assert merged_sd[key].size() == lora_sd[key].size(
), f"weights shape mismatch merging v1 and v2, different dims? / 重みのサイズが合いません。v1とv2、または次元数の異なるモデルはマージできません"
merged_sd[key] = merged_sd[key] + lora_sd[key] * ratio
else:
if "lora_down" in key:
dim = lora_sd[key].size()[0]
merged_sd[key] = lora_sd[key] * ratio
return merged_sd print(f"dim (rank): {dim}, alpha: {alpha}")
if alpha is None:
alpha = dim
return merged_sd, dim, alpha
def merge(args): def merge(args):
@@ -132,7 +152,7 @@ def merge(args):
model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet, model_util.save_stable_diffusion_checkpoint(args.v2, args.save_to, text_encoder, unet,
args.sd_model, 0, 0, save_dtype, vae) args.sd_model, 0, 0, save_dtype, vae)
else: else:
state_dict = merge_lora_models(args.models, args.ratios, merge_dtype) state_dict, _, _ = merge_lora_models(args.models, args.ratios, merge_dtype)
print(f"saving model to: {args.save_to}") print(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype) save_to_file(args.save_to, state_dict, state_dict, save_dtype)
@@ -145,7 +165,7 @@ if __name__ == '__main__':
parser.add_argument("--save_precision", type=str, default=None, parser.add_argument("--save_precision", type=str, default=None,
choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ") choices=[None, "float", "fp16", "bf16"], help="precision in saving, same to merging if omitted / 保存時に精度を変更して保存する、省略時はマージ時の精度と同じ")
parser.add_argument("--precision", type=str, default="float", parser.add_argument("--precision", type=str, default="float",
choices=["float", "fp16", "bf16"], help="precision in merging / マージの計算時の精度") choices=["float", "fp16", "bf16"], help="precision in merging (float is recommended) / マージの計算時の精度floatを推奨")
parser.add_argument("--sd_model", type=str, default=None, parser.add_argument("--sd_model", type=str, default=None,
help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする") help="Stable Diffusion model to load: ckpt or safetensors file, merge LoRA models if omitted / 読み込むモデル、ckptまたはsafetensors。省略時はLoRAモデル同士をマージする")
parser.add_argument("--save_to", type=str, default=None, parser.add_argument("--save_to", type=str, default=None,

View File

@@ -92,10 +92,7 @@ def train(args):
gc.collect() gc.collect()
# 学習を準備する:モデルを適切な状態にする # 学習を準備する:モデルを適切な状態にする
if args.stop_text_encoder_training is None: train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
train_text_encoder = args.stop_text_encoder_training >= 0
unet.requires_grad_(True) # 念のため追加 unet.requires_grad_(True) # 念のため追加
text_encoder.requires_grad_(train_text_encoder) text_encoder.requires_grad_(train_text_encoder)
if not train_text_encoder: if not train_text_encoder:
@@ -143,6 +140,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * len(train_dataloader) args.max_train_steps = args.max_train_epochs * len(train_dataloader)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
if args.stop_text_encoder_training is None:
args.stop_text_encoder_training = args.max_train_steps + 1 # do not stop until end
# lr schedulerを用意する # lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler( lr_scheduler = diffusers.optimization.get_scheduler(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps) args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
@@ -176,6 +176,8 @@ def train(args):
# epoch数を計算する # epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する # 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

View File

@@ -72,7 +72,7 @@ identifierとclassを使い、たとえば「shs dog」などでモデルを学
※LoRA等の追加ネットワークを学習する場合のコマンドは ``train_db.py`` ではなく ``train_network.py`` となります。また追加でnetwork_\*オプションが必要となりますので、LoRAのガイドを参照してください。 ※LoRA等の追加ネットワークを学習する場合のコマンドは ``train_db.py`` ではなく ``train_network.py`` となります。また追加でnetwork_\*オプションが必要となりますので、LoRAのガイドを参照してください。
``` ```
accelerate launch --num_cpu_threads_per_process 8 train_db.py accelerate launch --num_cpu_threads_per_process 1 train_db.py
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ> --pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
--train_data_dir=<学習用データのディレクトリ> --train_data_dir=<学習用データのディレクトリ>
--reg_data_dir=<正則化画像のディレクトリ> --reg_data_dir=<正則化画像のディレクトリ>
@@ -89,7 +89,7 @@ accelerate launch --num_cpu_threads_per_process 8 train_db.py
--gradient_checkpointing --gradient_checkpointing
``` ```
num_cpu_threads_per_processにはCPUコア数を指定するとよいようです。 num_cpu_threads_per_processには通常は1を指定するとよいようです。
pretrained_model_name_or_pathに追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル.ckptまたは.safetensors、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID"stabilityai/stable-diffusion-2"などが指定できます。学習後のモデルの保存形式はデフォルトでは元のモデルと同じになりますsave_model_asオプションで変更できます pretrained_model_name_or_pathに追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル.ckptまたは.safetensors、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID"stabilityai/stable-diffusion-2"などが指定できます。学習後のモデルの保存形式はデフォルトでは元のモデルと同じになりますsave_model_asオプションで変更できます
@@ -159,7 +159,7 @@ v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述
![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png) ![image](https://user-images.githubusercontent.com/52813779/210776915-061d79c3-6582-42c2-8884-8b91d2f07313.png)
各yamlファイルは[https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion](Stability AIのSD2.0のリポジトリ)にあります。 各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。
# その他の学習オプション # その他の学習オプション

View File

@@ -3,6 +3,9 @@ import argparse
import gc import gc
import math import math
import os import os
import random
import time
import json
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@@ -18,7 +21,23 @@ def collate_fn(examples):
return examples[0] return examples[0]
def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if args.network_train_unet_only:
logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
elif args.network_train_text_encoder_only:
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
else:
logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
logs["lr/unet"] = lr_scheduler.get_last_lr()[-1] # may be same to textencoder
return logs
def train(args): def train(args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
train_util.verify_training_args(args) train_util.verify_training_args(args)
train_util.prepare_dataset_args(args, True) train_util.prepare_dataset_args(args, True)
@@ -88,7 +107,8 @@ def train(args):
key, value = net_arg.split('=') key, value = net_arg.split('=')
net_kwargs[key] = value net_kwargs[key] = value
network = network_module.create_network(1.0, args.network_dim, vae, text_encoder, unet, **net_kwargs) # if a new network is added in future, add if ~ then blocks for each network (;'∀')
network = network_module.create_network(1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
if network is None: if network is None:
return return
@@ -166,6 +186,9 @@ def train(args):
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
unet.train() unet.train()
text_encoder.train() text_encoder.train()
# set top parameter requires_grad = True for gradient checkpointing works
text_encoder.text_model.embeddings.requires_grad_(True)
else: else:
unet.eval() unet.eval()
text_encoder.eval() text_encoder.eval()
@@ -189,6 +212,8 @@ def train(args):
# epoch数を計算する # epoch数を計算する
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
# 学習する # 学習する
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -203,21 +228,26 @@ def train(args):
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}") print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
metadata = { metadata = {
"ss_session_id": session_id, # random integer indicating which group of epochs the model came from
"ss_training_started_at": training_started_at, # unix timestamp
"ss_output_name": args.output_name,
"ss_learning_rate": args.learning_rate, "ss_learning_rate": args.learning_rate,
"ss_text_encoder_lr": args.text_encoder_lr, "ss_text_encoder_lr": args.text_encoder_lr,
"ss_unet_lr": args.unet_lr, "ss_unet_lr": args.unet_lr,
"ss_num_train_images": train_dataset.num_train_images, # includes repeating TODO more detailed data "ss_num_train_images": train_dataset.num_train_images, # includes repeating
"ss_num_reg_images": train_dataset.num_reg_images, "ss_num_reg_images": train_dataset.num_reg_images,
"ss_num_batches_per_epoch": len(train_dataloader), "ss_num_batches_per_epoch": len(train_dataloader),
"ss_num_epochs": num_train_epochs, "ss_num_epochs": num_train_epochs,
"ss_batch_size_per_device": args.train_batch_size, "ss_batch_size_per_device": args.train_batch_size,
"ss_total_batch_size": total_batch_size, "ss_total_batch_size": total_batch_size,
"ss_gradient_checkpointing": args.gradient_checkpointing,
"ss_gradient_accumulation_steps": args.gradient_accumulation_steps, "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
"ss_max_train_steps": args.max_train_steps, "ss_max_train_steps": args.max_train_steps,
"ss_lr_warmup_steps": args.lr_warmup_steps, "ss_lr_warmup_steps": args.lr_warmup_steps,
"ss_lr_scheduler": args.lr_scheduler, "ss_lr_scheduler": args.lr_scheduler,
"ss_network_module": args.network_module, "ss_network_module": args.network_module,
"ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim "ss_network_dim": args.network_dim, # None means default because another network than LoRA may have another default dim
"ss_network_alpha": args.network_alpha, # some networks may not use this value
"ss_mixed_precision": args.mixed_precision, "ss_mixed_precision": args.mixed_precision,
"ss_full_fp16": bool(args.full_fp16), "ss_full_fp16": bool(args.full_fp16),
"ss_v2": bool(args.v2), "ss_v2": bool(args.v2),
@@ -229,10 +259,15 @@ def train(args):
"ss_random_crop": bool(args.random_crop), "ss_random_crop": bool(args.random_crop),
"ss_shuffle_caption": bool(args.shuffle_caption), "ss_shuffle_caption": bool(args.shuffle_caption),
"ss_cache_latents": bool(args.cache_latents), "ss_cache_latents": bool(args.cache_latents),
"ss_enable_bucket": bool(train_dataset.enable_bucket), # TODO move to BaseDataset from DB/FT "ss_enable_bucket": bool(train_dataset.enable_bucket),
"ss_min_bucket_reso": args.min_bucket_reso, # TODO get from dataset "ss_min_bucket_reso": train_dataset.min_bucket_reso,
"ss_max_bucket_reso": args.max_bucket_reso, "ss_max_bucket_reso": train_dataset.max_bucket_reso,
"ss_seed": args.seed "ss_seed": args.seed,
"ss_keep_tokens": args.keep_tokens,
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
"ss_training_comment": args.training_comment # will not be updated after training
} }
# uncomment if another network is added # uncomment if another network is added
@@ -243,6 +278,7 @@ def train(args):
sd_model_name = args.pretrained_model_name_or_path sd_model_name = args.pretrained_model_name_or_path
if os.path.exists(sd_model_name): if os.path.exists(sd_model_name):
metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name) metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
sd_model_name = os.path.basename(sd_model_name) sd_model_name = os.path.basename(sd_model_name)
metadata["ss_sd_model_name"] = sd_model_name metadata["ss_sd_model_name"] = sd_model_name
@@ -250,6 +286,7 @@ def train(args):
vae_name = args.vae vae_name = args.vae
if os.path.exists(vae_name): if os.path.exists(vae_name):
metadata["ss_vae_hash"] = train_util.model_hash(vae_name) metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
vae_name = os.path.basename(vae_name) vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name metadata["ss_vae_name"] = vae_name
@@ -330,20 +367,20 @@ def train(args):
global_step += 1 global_step += 1
current_loss = loss.detach().item() current_loss = loss.detach().item()
if args.logging_dir is not None:
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
accelerator.log(logs, step=global_step)
loss_total += current_loss loss_total += current_loss
avr_loss = loss_total / (step+1) avr_loss = loss_total / (step+1)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs) progress_bar.set_postfix(**logs)
if args.logging_dir is not None:
logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"epoch_loss": loss_total / len(train_dataloader)} logs = {"loss/epoch": loss_total / len(train_dataloader)}
accelerator.log(logs, step=epoch+1) accelerator.log(logs, step=epoch+1)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -364,9 +401,9 @@ def train(args):
print(f"removing old checkpoint: {old_ckpt_file}") print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file) os.remove(old_ckpt_file)
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs) saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
if saving and args.save_state: if saving and args.save_state:
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no) train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
# end of epoch # end of epoch
@@ -403,8 +440,8 @@ if __name__ == '__main__':
train_util.add_training_arguments(parser, True) train_util.add_training_arguments(parser, True)
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない") parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"], parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt") help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors")
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
@@ -414,11 +451,15 @@ if __name__ == '__main__':
parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール') parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
parser.add_argument("--network_dim", type=int, default=None, parser.add_argument("--network_dim", type=int, default=None,
help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)') help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
parser.add_argument("--network_alpha", type=float, default=1,
help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定')
parser.add_argument("--network_args", type=str, default=None, nargs='*', parser.add_argument("--network_args", type=str, default=None, nargs='*',
help='additional argmuments for network (key=value) / ネットワークへの追加の引数') help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する") parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
parser.add_argument("--network_train_text_encoder_only", action="store_true", parser.add_argument("--network_train_text_encoder_only", action="store_true",
help="only training Text Encoder part / Text Encoder関連部分のみ学習する") help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
parser.add_argument("--training_comment", type=str, default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
args = parser.parse_args() args = parser.parse_args()
train(args) train(args)

View File

@@ -24,7 +24,7 @@ DreamBoothの手法identifiersksなどとclass、オプションで正
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。 [DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。 学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション``network_dim``や``network_alpha``など)を追加してください。
ほぼすべてのオプションStable Diffusionのモデル保存関係を除くが使えますが、stop_text_encoder_trainingはサポートしていません。 ほぼすべてのオプションStable Diffusionのモデル保存関係を除くが使えますが、stop_text_encoder_trainingはサポートしていません。
@@ -32,7 +32,7 @@ DreamBoothの手法identifiersksなどとclass、オプションで正
[fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。 [fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプションモデル保存関係を除くがそのまま使えます。 学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプションモデル保存関係を除くがそのまま使えます。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション``network_dim``や``network_alpha``など)を追加してください。
なお「latentsの事前取得」は行わなくても動作します。VAEから学習時またはキャッシュ時にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。 なお「latentsの事前取得」は行わなくても動作します。VAEから学習時またはキャッシュ時にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。
@@ -45,7 +45,7 @@ train_network.pyでは--network_moduleオプションに、学習対象のモジ
以下はコマンドラインの例ですDreamBooth手法 以下はコマンドラインの例ですDreamBooth手法
``` ```
accelerate launch --num_cpu_threads_per_process 12 train_network.py accelerate launch --num_cpu_threads_per_process 1 train_network.py
--pretrained_model_name_or_path=..\models\model.ckpt --pretrained_model_name_or_path=..\models\model.ckpt
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1 --train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0 --reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
@@ -60,7 +60,9 @@ accelerate launch --num_cpu_threads_per_process 12 train_network.py
その他、以下のオプションが指定できます。 その他、以下のオプションが指定できます。
* --network_dim * --network_dim
* LoRAの次元数を指定します(``--networkdim=4``など。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。 * LoRAのRANKを指定します(``--networkdim=4``など。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
* --network_alpha
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
* --network_weights * --network_weights
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。 * 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
* --network_train_unet_only * --network_train_unet_only
@@ -126,7 +128,7 @@ python networks\merge_lora.py
--ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。 --ratiosにそれぞれのモデルの比率どのくらい重みを元モデルに反映するかを0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。 v1で学習したLoRAとv2で学習したLoRA、rank次元数や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
### その他のオプション ### その他のオプション
@@ -138,7 +140,7 @@ v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマ
## 当リポジトリ内の画像生成スクリプトで生成する ## 当リポジトリ内の画像生成スクリプトで生成する
gen_img_diffusers.pyに、--network_module、--network_weights、--network_dim省略可の各オプションを追加してください。意味は学習時と同様です。 gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
--network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。 --network_mulオプションで0~1.0の数値を指定すると、LoRAの適用率を変えられます。