mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'main' into caption-frequency-metadata
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
## リポジトリについて
|
## リポジトリについて
|
||||||
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
|
Stable Diffusionの学習、画像生成、その他のスクリプトを入れたリポジトリです。
|
||||||
|
|
||||||
[README in English](./README.md)
|
[README in English](./README.md) ←更新情報はこちらにあります
|
||||||
|
|
||||||
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
|
GUIやPowerShellスクリプトなど、より使いやすくする機能が[bmaltais氏のリポジトリ](https://github.com/bmaltais/kohya_ss)で提供されています(英語です)のであわせてご覧ください。bmaltais氏に感謝します。
|
||||||
|
|
||||||
@@ -20,6 +20,7 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
|
|||||||
* [fine-tuningのガイド](./fine_tune_README_ja.md):
|
* [fine-tuningのガイド](./fine_tune_README_ja.md):
|
||||||
BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます
|
BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます
|
||||||
* [LoRAの学習について](./train_network_README-ja.md)
|
* [LoRAの学習について](./train_network_README-ja.md)
|
||||||
|
* [Textual Inversionの学習について](./train_ti_README-ja.md)
|
||||||
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
|
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
|
||||||
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
|
* note.com [モデル変換スクリプト](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||||
|
|
||||||
@@ -103,6 +104,10 @@ accelerate configの質問には以下のように答えてください。(bf1
|
|||||||
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
|
※場合によって ``ValueError: fp16 mixed precision requires a GPU`` というエラーが出ることがあるようです。この場合、6番目の質問(
|
||||||
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
|
``What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]:``)に「0」と答えてください。(id `0`のGPUが使われます。)
|
||||||
|
|
||||||
|
### PyTorchとxformersのバージョンについて
|
||||||
|
|
||||||
|
他のバージョンでは学習がうまくいかない場合があるようです。特に他の理由がなければ指定のバージョンをお使いください。
|
||||||
|
|
||||||
## アップグレード
|
## アップグレード
|
||||||
|
|
||||||
新しいリリースがあった場合、以下のコマンドで更新できます。
|
新しいリリースがあった場合、以下のコマンドで更新できます。
|
||||||
|
|||||||
47
README.md
47
README.md
@@ -4,33 +4,28 @@ This repository contains training, generation and utility scripts for Stable Dif
|
|||||||
|
|
||||||
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
|
__Stable Diffusion web UI now seems to support LoRA trained by ``sd-scripts``.__ Thank you for great work!!!
|
||||||
|
|
||||||
Note: Currently the LoRA models trained by release v0.4.0 does not seem to be supported. If you use Web UI native LoRA support, please use release 0.3.2 for now. The LoRA models for SD 2.x is not supported too in Web UI.
|
Note: The LoRA models for SD 2.x is not supported too in Web UI.
|
||||||
|
|
||||||
- Release v0.4.0: 22 Jan. 2023
|
- 26 Jan. 2023, 2023/1/26
|
||||||
- Add ``--network_alpha`` option to specify ``alpha`` value to prevent underflows for stable training. Thanks to CCRcmcpe!
|
- Add Textual Inversion training. Documentation is [here](./train_ti_README-ja.md) (in Japanese.)
|
||||||
- Details of the issue are described in https://github.com/kohya-ss/sd-webui-additional-networks/issues/49 .
|
- Textual Inversionの学習をサポートしました。ドキュメントは[こちら](./train_ti_README-ja.md)。
|
||||||
- The default value is ``1``, scale ``1 / rank (or dimension)``. Set same value as ``network_dim`` for same behavior to old version.
|
- 24 Jan. 2023, 2023/1/24
|
||||||
- LoRA with a large dimension (rank) seems to require a higher learning rate with ``alpha=1`` (e.g. 1e-3 for 128-dim, still investigating).
|
- Change the default save format to ``.safetensors`` for ``train_network.py``.
|
||||||
- For generating images in Web UI, __the latest version of the extension ``sd-webui-additional-networks`` (v0.3.0 or later) is required for the models trained with this release or later.__
|
- Add ``--save_n_epoch_ratio`` option to specify how often to save. Thanks to forestsource!
|
||||||
- Add logging for the learning rate for U-Net and Text Encoder independently, and for running average epoch loss. Thanks to mgz-dev!
|
- For example, if 5 is specified, 5 (or 6) files will be saved in training.
|
||||||
- Add more metadata such as dataset/reg image dirs, session ID, output name etc... See https://github.com/kohya-ss/sd-scripts/pull/77 for details. Thanks to space-nuko!
|
- Add feature to pre-calculate hash to reduce loading time in the extension. Thanks to space-nuko!
|
||||||
- __Now the metadata includes the folder name (the basename of the folder contains image files, not fullpath).__ If you do not want it, disable metadata storing with ``--no_metadata`` option.
|
- Add bucketing metadata. Thanks to space-nuko!
|
||||||
- Add ``--training_comment`` option. You can specify an arbitrary string and refer to it by the extension.
|
- Fix an error with bf16 model in ``gen_img_diffusers.py``.
|
||||||
|
- ``train_network.py`` のモデル保存形式のデフォルトを ``.safetensors`` に変更しました。
|
||||||
|
- モデルを保存する頻度を指定する ``--save_n_epoch_ratio`` オプションが追加されました。forestsource氏に感謝します。
|
||||||
|
- たとえば 5 を指定すると、学習終了までに合計で5個(または6個)のファイルが保存されます。
|
||||||
|
- 拡張でモデル読み込み時間を短縮するためのハッシュ事前計算の機能を追加しました。space-nuko氏に感謝します。
|
||||||
|
- メタデータにbucket情報が追加されました。space-nuko氏に感謝します。
|
||||||
|
- ``gen_img_diffusers.py`` でbf16形式のモデルを読み込んだときのエラーを修正しました。
|
||||||
|
|
||||||
Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。
|
Stable Diffusion web UI本体で当リポジトリで学習したLoRAモデルによる画像生成がサポートされたようです。
|
||||||
|
|
||||||
注:現時点ではversion 0.4.0で学習したモデルはサポートされないようです。Web UI本体の生成機能を使う場合には、version 0.3.2を引き続きご利用ください。またSD2.x用のLoRAモデルもサポートされないようです。
|
注:SD2.x用のLoRAモデルはサポートされないようです。
|
||||||
|
|
||||||
- Release 0.4.0: 2023/1/22
|
|
||||||
- アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定する、``--network_alpha`` オプションを追加しました。CCRcmcpe 氏に感謝します。
|
|
||||||
- 問題の詳細はこちらをご覧ください: https://github.com/kohya-ss/sd-webui-additional-networks/issues/49
|
|
||||||
- デフォルト値は ``1`` で、LoRAの計算結果を ``1 / rank (dimension・次元数)`` 倍します(つまり小さくなります。これにより同じ効果を出すために必要なLoRAの重みの変化が大きくなるため、アンダーフローが避けられるようになります)。``network_dim`` と同じ値を指定すると旧バージョンと同じ動作になります。
|
|
||||||
- ``alpha=1``の場合、次元数(rank)の多いLoRAモジュールでは学習率を高めにしたほうが良いようです(128次元で1e-3など)。
|
|
||||||
- __このバージョンのスクリプトで学習したモデルをWeb UIで使うためには ``sd-webui-additional-networks`` 拡張の最新版(v0.3.0以降)が必要となります。__
|
|
||||||
- U-Net と Text Encoder のそれぞれの学習率、エポックの平均lossをログに記録するようになりました。mgz-dev 氏に感謝します。
|
|
||||||
- 画像ディレクトリ、セッションID、出力名などいくつかの項目がメタデータに追加されました(詳細は https://github.com/kohya-ss/sd-scripts/pull/77 を参照)。space-nuko氏に感謝します。
|
|
||||||
- __メタデータにフォルダ名が含まれるようになりました(画像を含むフォルダの名前のみで、フルパスではありません)。__ もし望まない場合には ``--no_metadata`` オプションでメタデータの記録を止めてください。
|
|
||||||
- ``--training_comment`` オプションを追加しました。任意の文字列を指定でき、Web UI拡張から参照できます。
|
|
||||||
|
|
||||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||||
@@ -63,6 +58,7 @@ All documents are in Japanese currently, and CUI based.
|
|||||||
* [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
|
* [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
|
||||||
Including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
|
Including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
|
||||||
* [training LoRA](./train_network_README-ja.md)
|
* [training LoRA](./train_network_README-ja.md)
|
||||||
|
* [training Textual Inversion](./train_ti_README-ja.md)
|
||||||
* note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
|
* note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
|
||||||
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
|
* note.com [Model conversion](https://note.com/kohya_ss/n/n374f316fe4ad)
|
||||||
|
|
||||||
@@ -120,6 +116,11 @@ note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
|||||||
|
|
||||||
(Single GPU with id `0` will be used.)
|
(Single GPU with id `0` will be used.)
|
||||||
|
|
||||||
|
### about PyTorch and xformers
|
||||||
|
|
||||||
|
Other versions of PyTorch and xformers seem to have problems with training.
|
||||||
|
If there is no other reason, please install the specified version.
|
||||||
|
|
||||||
## Upgrade
|
## Upgrade
|
||||||
|
|
||||||
When a new release comes out you can upgrade your repo with the following command:
|
When a new release comes out you can upgrade your repo with the following command:
|
||||||
|
|||||||
@@ -200,6 +200,8 @@ def train(args):
|
|||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
# 学習する
|
# 学習する
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
|||||||
@@ -324,7 +324,7 @@ __※引数を都度書き換えて、別のメタデータファイルに書き
|
|||||||
## 学習の実行
|
## 学習の実行
|
||||||
たとえば以下のように実行します。以下は省メモリ化のための設定です。
|
たとえば以下のように実行します。以下は省メモリ化のための設定です。
|
||||||
```
|
```
|
||||||
accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
|
accelerate launch --num_cpu_threads_per_process 1 fine_tune.py
|
||||||
--pretrained_model_name_or_path=model.ckpt
|
--pretrained_model_name_or_path=model.ckpt
|
||||||
--in_json meta_lat.json
|
--in_json meta_lat.json
|
||||||
--train_data_dir=train_data
|
--train_data_dir=train_data
|
||||||
@@ -336,7 +336,7 @@ accelerate launch --num_cpu_threads_per_process 8 fine_tune.py
|
|||||||
--save_every_n_epochs=4
|
--save_every_n_epochs=4
|
||||||
```
|
```
|
||||||
|
|
||||||
accelerateのnum_cpu_threads_per_processにはCPUのコア数を指定するとよいようです。
|
accelerateのnum_cpu_threads_per_processには通常は1を指定するとよいようです。
|
||||||
|
|
||||||
pretrained_model_name_or_pathに学習対象のモデルを指定します(Stable DiffusionのcheckpointかDiffusersのモデル)。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています(拡張子で自動判定)。
|
pretrained_model_name_or_pathに学習対象のモデルを指定します(Stable DiffusionのcheckpointかDiffusersのモデル)。Stable Diffusionのcheckpointは.ckptと.safetensorsに対応しています(拡張子で自動判定)。
|
||||||
|
|
||||||
|
|||||||
@@ -470,6 +470,9 @@ class PipelineLike():
|
|||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self.safety_checker = None
|
self.safety_checker = None
|
||||||
|
|
||||||
|
# Textual Inversion
|
||||||
|
self.token_replacements = {}
|
||||||
|
|
||||||
# CLIP guidance
|
# CLIP guidance
|
||||||
self.clip_guidance_scale = clip_guidance_scale
|
self.clip_guidance_scale = clip_guidance_scale
|
||||||
self.clip_image_guidance_scale = clip_image_guidance_scale
|
self.clip_image_guidance_scale = clip_image_guidance_scale
|
||||||
@@ -484,6 +487,19 @@ class PipelineLike():
|
|||||||
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
self.vgg16_feat_model = torchvision.models._utils.IntermediateLayerGetter(vgg16_model.features, return_layers=return_layers)
|
||||||
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
self.vgg16_normalize = transforms.Normalize(mean=VGG16_IMAGE_MEAN, std=VGG16_IMAGE_STD)
|
||||||
|
|
||||||
|
# Textual Inversion
|
||||||
|
def add_token_replacement(self, target_token_id, rep_token_ids):
|
||||||
|
self.token_replacements[target_token_id] = rep_token_ids
|
||||||
|
|
||||||
|
def replace_token(self, tokens):
|
||||||
|
new_tokens = []
|
||||||
|
for token in tokens:
|
||||||
|
if token in self.token_replacements:
|
||||||
|
new_tokens.extend(self.token_replacements[token])
|
||||||
|
else:
|
||||||
|
new_tokens.append(token)
|
||||||
|
return new_tokens
|
||||||
|
|
||||||
# region xformersとか使う部分:独自に書き換えるので関係なし
|
# region xformersとか使う部分:独自に書き換えるので関係なし
|
||||||
def enable_xformers_memory_efficient_attention(self):
|
def enable_xformers_memory_efficient_attention(self):
|
||||||
r"""
|
r"""
|
||||||
@@ -1507,6 +1523,9 @@ def get_prompts_with_weights(pipe: PipelineLike, prompt: List[str], max_length:
|
|||||||
for word, weight in texts_and_weights:
|
for word, weight in texts_and_weights:
|
||||||
# tokenize and discard the starting and the ending token
|
# tokenize and discard the starting and the ending token
|
||||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||||
|
|
||||||
|
token = pipe.replace_token(token)
|
||||||
|
|
||||||
text_token += token
|
text_token += token
|
||||||
# copy the weight by length of token
|
# copy the weight by length of token
|
||||||
text_weight += [weight] * len(token)
|
text_weight += [weight] * len(token)
|
||||||
@@ -2039,6 +2058,44 @@ def main(args):
|
|||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
# Textual Inversionを処理する
|
||||||
|
if args.textual_inversion_embeddings:
|
||||||
|
token_ids_embeds = []
|
||||||
|
for embeds_file in args.textual_inversion_embeddings:
|
||||||
|
if model_util.is_safetensors(embeds_file):
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
data = load_file(embeds_file)
|
||||||
|
else:
|
||||||
|
data = torch.load(embeds_file, map_location="cpu")
|
||||||
|
|
||||||
|
embeds = next(iter(data.values()))
|
||||||
|
if type(embeds) != torch.Tensor:
|
||||||
|
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {embeds_file}")
|
||||||
|
|
||||||
|
num_vectors_per_token = embeds.size()[0]
|
||||||
|
token_string = os.path.splitext(os.path.basename(embeds_file))[0]
|
||||||
|
token_strings = [token_string] + [f"{token_string}{i+1}" for i in range(num_vectors_per_token - 1)]
|
||||||
|
|
||||||
|
# add new word to tokenizer, count is num_vectors_per_token
|
||||||
|
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||||
|
assert num_added_tokens == num_vectors_per_token, f"tokenizer has same word to token string (filename). please rename the file / 指定した名前(ファイル名)のトークンが既に存在します。ファイルをリネームしてください: {embeds_file}"
|
||||||
|
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
|
print(f"Textual Inversion embeddings `{token_string}` loaded. Tokens are added: {token_ids}")
|
||||||
|
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||||
|
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||||
|
|
||||||
|
if num_vectors_per_token > 1:
|
||||||
|
pipe.add_token_replacement(token_ids[0], token_ids)
|
||||||
|
|
||||||
|
token_ids_embeds.append((token_ids, embeds))
|
||||||
|
|
||||||
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||||
|
for token_ids, embeds in token_ids_embeds:
|
||||||
|
for token_id, embed in zip(token_ids, embeds):
|
||||||
|
token_embeds[token_id] = embed
|
||||||
|
|
||||||
# promptを取得する
|
# promptを取得する
|
||||||
if args.from_file is not None:
|
if args.from_file is not None:
|
||||||
print(f"reading prompts from {args.from_file}")
|
print(f"reading prompts from {args.from_file}")
|
||||||
@@ -2157,8 +2214,8 @@ def main(args):
|
|||||||
os.makedirs(args.outdir, exist_ok=True)
|
os.makedirs(args.outdir, exist_ok=True)
|
||||||
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
max_embeddings_multiples = 1 if args.max_embeddings_multiples is None else args.max_embeddings_multiples
|
||||||
|
|
||||||
for iter in range(args.n_iter):
|
for gen_iter in range(args.n_iter):
|
||||||
print(f"iteration {iter+1}/{args.n_iter}")
|
print(f"iteration {gen_iter+1}/{args.n_iter}")
|
||||||
iter_seed = random.randint(0, 0x7fffffff)
|
iter_seed = random.randint(0, 0x7fffffff)
|
||||||
|
|
||||||
# バッチ処理の関数
|
# バッチ処理の関数
|
||||||
@@ -2527,6 +2584,8 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
parser.add_argument("--network_mul", type=float, default=None, nargs='*', help='Hypernetwork multiplier / Hypernetworkの効果の倍率')
|
||||||
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
parser.add_argument("--network_args", type=str, default=None, nargs='*',
|
||||||
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
|
||||||
|
parser.add_argument("--textual_inversion_embeddings", type=str, default=None, nargs='*',
|
||||||
|
help='Embeddings files of Textual Inversion / Textual Inversionのembeddings')
|
||||||
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
parser.add_argument("--clip_skip", type=int, default=None, help='layer number from bottom to use in CLIP / CLIPの後ろからn層目の出力を使う')
|
||||||
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
parser.add_argument("--max_embeddings_multiples", type=int, default=None,
|
||||||
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
help='max embeding multiples, max token length is 75 * multiples / トークン長をデフォルトの何倍とするか 75*この値 がトークン長となる')
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -25,6 +26,7 @@ from PIL import Image
|
|||||||
import cv2
|
import cv2
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
import safetensors.torch
|
||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
|
|
||||||
@@ -86,6 +88,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.min_bucket_reso = None
|
self.min_bucket_reso = None
|
||||||
self.max_bucket_reso = None
|
self.max_bucket_reso = None
|
||||||
self.tag_frequency = {}
|
self.tag_frequency = {}
|
||||||
|
self.bucket_info = None
|
||||||
|
|
||||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||||
|
|
||||||
@@ -111,9 +114,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
self.image_data: dict[str, ImageInfo] = {}
|
self.image_data: dict[str, ImageInfo] = {}
|
||||||
|
|
||||||
|
self.replacements = {}
|
||||||
|
|
||||||
def disable_token_padding(self):
|
def disable_token_padding(self):
|
||||||
self.token_padding_disabled = True
|
self.token_padding_disabled = True
|
||||||
|
|
||||||
|
def add_replacement(self, str_from, str_to):
|
||||||
|
self.replacements[str_from] = str_to
|
||||||
|
|
||||||
def process_caption(self, caption):
|
def process_caption(self, caption):
|
||||||
if self.shuffle_caption:
|
if self.shuffle_caption:
|
||||||
tokens = caption.strip().split(",")
|
tokens = caption.strip().split(",")
|
||||||
@@ -126,6 +134,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
random.shuffle(tokens)
|
random.shuffle(tokens)
|
||||||
tokens = keep_tokens + tokens
|
tokens = keep_tokens + tokens
|
||||||
caption = ",".join(tokens).strip()
|
caption = ",".join(tokens).strip()
|
||||||
|
|
||||||
|
for str_from, str_to in self.replacements.items():
|
||||||
|
if str_from == "":
|
||||||
|
# replace all
|
||||||
|
if type(str_to) == list:
|
||||||
|
caption = random.choice(str_to)
|
||||||
|
else:
|
||||||
|
caption = str_to
|
||||||
|
else:
|
||||||
|
caption = caption.replace(str_from, str_to)
|
||||||
|
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
def get_input_ids(self, caption):
|
def get_input_ids(self, caption):
|
||||||
@@ -218,11 +237,17 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
self.buckets[bucket_index].append(image_info.image_key)
|
self.buckets[bucket_index].append(image_info.image_key)
|
||||||
|
|
||||||
if self.enable_bucket:
|
if self.enable_bucket:
|
||||||
|
self.bucket_info = {"buckets": {}}
|
||||||
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
print("number of images (including repeats) / 各bucketの画像枚数(繰り返し回数を含む)")
|
||||||
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
|
for i, (reso, img_keys) in enumerate(zip(bucket_resos, self.buckets)):
|
||||||
|
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(img_keys)}
|
||||||
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
|
print(f"bucket {i}: resolution {reso}, count: {len(img_keys)}")
|
||||||
|
|
||||||
img_ar_errors = np.array(img_ar_errors)
|
img_ar_errors = np.array(img_ar_errors)
|
||||||
print(f"mean ar error (without repeats): {np.mean(np.abs(img_ar_errors))}")
|
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||||
|
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||||
|
print(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||||
|
|
||||||
|
|
||||||
# 参照用indexを作る
|
# 参照用indexを作る
|
||||||
self.buckets_indices: list(BucketBatchIndex) = []
|
self.buckets_indices: list(BucketBatchIndex) = []
|
||||||
@@ -609,7 +634,7 @@ class FineTuningDataset(BaseDataset):
|
|||||||
else:
|
else:
|
||||||
# わりといい加減だがいい方法が思いつかん
|
# わりといい加減だがいい方法が思いつかん
|
||||||
abs_path = glob_images(train_data_dir, image_key)
|
abs_path = glob_images(train_data_dir, image_key)
|
||||||
assert len(abs_path) >= 1, f"no image / 画像がありません: {abs_path}"
|
assert len(abs_path) >= 1, f"no image / 画像がありません: {image_key}"
|
||||||
abs_path = abs_path[0]
|
abs_path = abs_path[0]
|
||||||
|
|
||||||
caption = img_md.get('caption')
|
caption = img_md.get('caption')
|
||||||
@@ -716,15 +741,17 @@ class FineTuningDataset(BaseDataset):
|
|||||||
return npz_file_norm, npz_file_flip
|
return npz_file_norm, npz_file_flip
|
||||||
|
|
||||||
|
|
||||||
def debug_dataset(train_dataset):
|
def debug_dataset(train_dataset, show_input_ids=False):
|
||||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||||
print("Escape for exit. / Escキーで中断、終了します")
|
print("Escape for exit. / Escキーで中断、終了します")
|
||||||
k = 0
|
k = 0
|
||||||
for example in train_dataset:
|
for example in train_dataset:
|
||||||
if example['latents'] is not None:
|
if example['latents'] is not None:
|
||||||
print("sample has latents from npz file")
|
print("sample has latents from npz file")
|
||||||
for j, (ik, cap, lw) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'])):
|
for j, (ik, cap, lw, iid) in enumerate(zip(example['image_keys'], example['captions'], example['loss_weights'], example['input_ids'])):
|
||||||
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
|
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, caption: "{cap}", loss weight: {lw}')
|
||||||
|
if show_input_ids:
|
||||||
|
print(f"input ids: {iid}")
|
||||||
if example['images'] is not None:
|
if example['images'] is not None:
|
||||||
im = example['images'][j]
|
im = example['images'][j]
|
||||||
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
|
||||||
@@ -800,6 +827,49 @@ def calculate_sha256(filename):
|
|||||||
return hash_sha256.hexdigest()
|
return hash_sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def precalculate_safetensors_hashes(tensors, metadata):
|
||||||
|
"""Precalculate the model hashes needed by sd-webui-additional-networks to
|
||||||
|
save time on indexing the model later."""
|
||||||
|
|
||||||
|
# Because writing user metadata to the file can change the result of
|
||||||
|
# sd_models.model_hash(), only retain the training metadata for purposes of
|
||||||
|
# calculating the hash, as they are meant to be immutable
|
||||||
|
metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}
|
||||||
|
|
||||||
|
bytes = safetensors.torch.save(tensors, metadata)
|
||||||
|
b = BytesIO(bytes)
|
||||||
|
|
||||||
|
model_hash = addnet_hash_safetensors(b)
|
||||||
|
legacy_hash = addnet_hash_legacy(b)
|
||||||
|
return model_hash, legacy_hash
|
||||||
|
|
||||||
|
|
||||||
|
def addnet_hash_legacy(b):
|
||||||
|
"""Old model hash used by sd-webui-additional-networks for .safetensors format files"""
|
||||||
|
m = hashlib.sha256()
|
||||||
|
|
||||||
|
b.seek(0x100000)
|
||||||
|
m.update(b.read(0x10000))
|
||||||
|
return m.hexdigest()[0:8]
|
||||||
|
|
||||||
|
|
||||||
|
def addnet_hash_safetensors(b):
|
||||||
|
"""New model hash used by sd-webui-additional-networks for .safetensors format files"""
|
||||||
|
hash_sha256 = hashlib.sha256()
|
||||||
|
blksize = 1024 * 1024
|
||||||
|
|
||||||
|
b.seek(0)
|
||||||
|
header = b.read(8)
|
||||||
|
n = int.from_bytes(header, "little")
|
||||||
|
|
||||||
|
offset = n + 8
|
||||||
|
b.seek(offset)
|
||||||
|
for chunk in iter(lambda: b.read(blksize), b""):
|
||||||
|
hash_sha256.update(chunk)
|
||||||
|
|
||||||
|
return hash_sha256.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
# flash attention forwards and backwards
|
# flash attention forwards and backwards
|
||||||
|
|
||||||
# https://arxiv.org/abs/2205.14135
|
# https://arxiv.org/abs/2205.14135
|
||||||
@@ -1067,6 +1137,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
|
||||||
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
parser.add_argument("--save_every_n_epochs", type=int, default=None,
|
||||||
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
|
||||||
|
parser.add_argument("--save_n_epoch_ratio", type=int, default=None,
|
||||||
|
help="save checkpoint N epoch ratio (for example 5 means save at least 5 files total) / 学習中のモデルを指定のエポック割合で保存する(たとえば5を指定すると最低5個のファイルが保存される)")
|
||||||
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
parser.add_argument("--save_last_n_epochs", type=int, default=None, help="save last N checkpoints / 最大Nエポック保存する")
|
||||||
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
parser.add_argument("--save_last_n_epochs_state", type=int, default=None,
|
||||||
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
help="save last N checkpoints of state (overrides the value of --save_last_n_epochs)/ 最大Nエポックstateを保存する(--save_last_n_epochsの指定を上書きします)")
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import math
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from library import train_util
|
||||||
|
|
||||||
|
|
||||||
class LoRAModule(torch.nn.Module):
|
class LoRAModule(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
@@ -31,7 +33,7 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
self.lora_up = torch.nn.Linear(lora_dim, out_dim, bias=False)
|
||||||
|
|
||||||
if type(alpha) == torch.Tensor:
|
if type(alpha) == torch.Tensor:
|
||||||
alpha = alpha.detach().numpy()
|
alpha = alpha.detach().float().numpy() # without casting, bf16 causes error
|
||||||
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
alpha = lora_dim if alpha is None or alpha == 0 else alpha
|
||||||
self.scale = alpha / self.lora_dim
|
self.scale = alpha / self.lora_dim
|
||||||
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
self.register_buffer('alpha', torch.tensor(alpha)) # 定数として扱える
|
||||||
@@ -221,6 +223,14 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
if os.path.splitext(file)[1] == '.safetensors':
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
# Precalculate model hashes to save time on indexing
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
|
metadata["sshs_model_hash"] = model_hash
|
||||||
|
metadata["sshs_legacy_hash"] = legacy_hash
|
||||||
|
|
||||||
save_file(state_dict, file, metadata)
|
save_file(state_dict, file, metadata)
|
||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
|||||||
@@ -176,6 +176,8 @@ def train(args):
|
|||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
# 学習する
|
# 学習する
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ identifierとclassを使い、たとえば「shs dog」などでモデルを学
|
|||||||
※LoRA等の追加ネットワークを学習する場合のコマンドは ``train_db.py`` ではなく ``train_network.py`` となります。また追加でnetwork_\*オプションが必要となりますので、LoRAのガイドを参照してください。
|
※LoRA等の追加ネットワークを学習する場合のコマンドは ``train_db.py`` ではなく ``train_network.py`` となります。また追加でnetwork_\*オプションが必要となりますので、LoRAのガイドを参照してください。
|
||||||
|
|
||||||
```
|
```
|
||||||
accelerate launch --num_cpu_threads_per_process 8 train_db.py
|
accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
||||||
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
||||||
--train_data_dir=<学習用データのディレクトリ>
|
--train_data_dir=<学習用データのディレクトリ>
|
||||||
--reg_data_dir=<正則化画像のディレクトリ>
|
--reg_data_dir=<正則化画像のディレクトリ>
|
||||||
@@ -89,7 +89,7 @@ accelerate launch --num_cpu_threads_per_process 8 train_db.py
|
|||||||
--gradient_checkpointing
|
--gradient_checkpointing
|
||||||
```
|
```
|
||||||
|
|
||||||
num_cpu_threads_per_processにはCPUコア数を指定するとよいようです。
|
num_cpu_threads_per_processには通常は1を指定するとよいようです。
|
||||||
|
|
||||||
pretrained_model_name_or_pathに追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。学習後のモデルの保存形式はデフォルトでは元のモデルと同じになります(save_model_asオプションで変更できます)。
|
pretrained_model_name_or_pathに追加学習を行う元となるモデルを指定します。Stable Diffusionのcheckpointファイル(.ckptまたは.safetensors)、Diffusersのローカルディスクにあるモデルディレクトリ、DiffusersのモデルID("stabilityai/stable-diffusion-2"など)が指定できます。学習後のモデルの保存形式はデフォルトでは元のモデルと同じになります(save_model_asオプションで変更できます)。
|
||||||
|
|
||||||
@@ -159,7 +159,7 @@ v2.xモデルでWebUIで画像生成する場合、モデルの仕様が記述
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
各yamlファイルは[https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion](Stability AIのSD2.0のリポジトリ)にあります。
|
各yamlファイルは[Stability AIのSD2.0のリポジトリ](https://github.com/Stability-AI/stablediffusion/tree/main/configs/stable-diffusion)にあります。
|
||||||
|
|
||||||
# その他の学習オプション
|
# その他の学習オプション
|
||||||
|
|
||||||
|
|||||||
@@ -212,6 +212,8 @@ def train(args):
|
|||||||
# epoch数を計算する
|
# epoch数を計算する
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
# 学習する
|
# 学習する
|
||||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
@@ -265,6 +267,7 @@ def train(args):
|
|||||||
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
"ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
|
||||||
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
"ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
|
||||||
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
"ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
|
||||||
|
"ss_bucket_info": json.dumps(train_dataset.bucket_info),
|
||||||
"ss_training_comment": args.training_comment # will not be updated after training
|
"ss_training_comment": args.training_comment # will not be updated after training
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -438,8 +441,8 @@ if __name__ == '__main__':
|
|||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
|
|
||||||
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
|
||||||
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")
|
||||||
|
|
||||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||||
|
|
||||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extention](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||||
|
|
||||||
## 学習方法
|
## 学習方法
|
||||||
|
|
||||||
@@ -24,7 +24,7 @@ DreamBoothの手法(identifier(sksなど)とclass、オプションで正
|
|||||||
|
|
||||||
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
|
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
|
||||||
|
|
||||||
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。
|
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。
|
||||||
|
|
||||||
ほぼすべてのオプション(Stable Diffusionのモデル保存関係を除く)が使えますが、stop_text_encoder_trainingはサポートしていません。
|
ほぼすべてのオプション(Stable Diffusionのモデル保存関係を除く)が使えますが、stop_text_encoder_trainingはサポートしていません。
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ DreamBoothの手法(identifier(sksなど)とclass、オプションで正
|
|||||||
|
|
||||||
[fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。
|
[fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。
|
||||||
|
|
||||||
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプション(モデル保存関係を除く)がそのまま使えます。
|
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプション(モデル保存関係を除く)がそのまま使えます。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。
|
||||||
|
|
||||||
なお「latentsの事前取得」は行わなくても動作します。VAEから学習時(またはキャッシュ時)にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。
|
なお「latentsの事前取得」は行わなくても動作します。VAEから学習時(またはキャッシュ時)にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。
|
||||||
|
|
||||||
@@ -45,7 +45,7 @@ train_network.pyでは--network_moduleオプションに、学習対象のモジ
|
|||||||
以下はコマンドラインの例です(DreamBooth手法)。
|
以下はコマンドラインの例です(DreamBooth手法)。
|
||||||
|
|
||||||
```
|
```
|
||||||
accelerate launch --num_cpu_threads_per_process 12 train_network.py
|
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||||
--pretrained_model_name_or_path=..\models\model.ckpt
|
--pretrained_model_name_or_path=..\models\model.ckpt
|
||||||
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
|
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
|
||||||
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
|
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
|
||||||
@@ -60,7 +60,9 @@ accelerate launch --num_cpu_threads_per_process 12 train_network.py
|
|||||||
その他、以下のオプションが指定できます。
|
その他、以下のオプションが指定できます。
|
||||||
|
|
||||||
* --network_dim
|
* --network_dim
|
||||||
* LoRAの次元数を指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
* LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
||||||
|
* --network_alpha
|
||||||
|
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
|
||||||
* --network_weights
|
* --network_weights
|
||||||
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
||||||
* --network_train_unet_only
|
* --network_train_unet_only
|
||||||
@@ -126,7 +128,7 @@ python networks\merge_lora.py
|
|||||||
|
|
||||||
--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
|
--ratiosにそれぞれのモデルの比率(どのくらい重みを元モデルに反映するか)を0~1.0の数値で指定します。二つのモデルを一対一でマージす場合は、「0.5 0.5」になります。「1.0 1.0」では合計の重みが大きくなりすぎて、恐らく結果はあまり望ましくないものになると思われます。
|
||||||
|
|
||||||
v1で学習したLoRAとv2で学習したLoRA、次元数の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``の異なるLoRAはマージできません。U-NetだけのLoRAとU-Net+Text EncoderのLoRAはマージできるはずですが、結果は未知数です。
|
||||||
|
|
||||||
|
|
||||||
### その他のオプション
|
### その他のオプション
|
||||||
|
|||||||
498
train_textual_inversion.py
Normal file
498
train_textual_inversion.py
Normal file
@@ -0,0 +1,498 @@
|
|||||||
|
import importlib
|
||||||
|
import argparse
|
||||||
|
import gc
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
from accelerate.utils import set_seed
|
||||||
|
import diffusers
|
||||||
|
from diffusers import DDPMScheduler
|
||||||
|
|
||||||
|
import library.train_util as train_util
|
||||||
|
from library.train_util import DreamBoothDataset, FineTuningDataset
|
||||||
|
|
||||||
|
imagenet_templates_small = [
|
||||||
|
"a photo of a {}",
|
||||||
|
"a rendering of a {}",
|
||||||
|
"a cropped photo of the {}",
|
||||||
|
"the photo of a {}",
|
||||||
|
"a photo of a clean {}",
|
||||||
|
"a photo of a dirty {}",
|
||||||
|
"a dark photo of the {}",
|
||||||
|
"a photo of my {}",
|
||||||
|
"a photo of the cool {}",
|
||||||
|
"a close-up photo of a {}",
|
||||||
|
"a bright photo of the {}",
|
||||||
|
"a cropped photo of a {}",
|
||||||
|
"a photo of the {}",
|
||||||
|
"a good photo of the {}",
|
||||||
|
"a photo of one {}",
|
||||||
|
"a close-up photo of the {}",
|
||||||
|
"a rendition of the {}",
|
||||||
|
"a photo of the clean {}",
|
||||||
|
"a rendition of a {}",
|
||||||
|
"a photo of a nice {}",
|
||||||
|
"a good photo of a {}",
|
||||||
|
"a photo of the nice {}",
|
||||||
|
"a photo of the small {}",
|
||||||
|
"a photo of the weird {}",
|
||||||
|
"a photo of the large {}",
|
||||||
|
"a photo of a cool {}",
|
||||||
|
"a photo of a small {}",
|
||||||
|
]
|
||||||
|
|
||||||
|
imagenet_style_templates_small = [
|
||||||
|
"a painting in the style of {}",
|
||||||
|
"a rendering in the style of {}",
|
||||||
|
"a cropped painting in the style of {}",
|
||||||
|
"the painting in the style of {}",
|
||||||
|
"a clean painting in the style of {}",
|
||||||
|
"a dirty painting in the style of {}",
|
||||||
|
"a dark painting in the style of {}",
|
||||||
|
"a picture in the style of {}",
|
||||||
|
"a cool painting in the style of {}",
|
||||||
|
"a close-up painting in the style of {}",
|
||||||
|
"a bright painting in the style of {}",
|
||||||
|
"a cropped painting in the style of {}",
|
||||||
|
"a good painting in the style of {}",
|
||||||
|
"a close-up painting in the style of {}",
|
||||||
|
"a rendition in the style of {}",
|
||||||
|
"a nice painting in the style of {}",
|
||||||
|
"a small painting in the style of {}",
|
||||||
|
"a weird painting in the style of {}",
|
||||||
|
"a large painting in the style of {}",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(examples):
|
||||||
|
return examples[0]
|
||||||
|
|
||||||
|
|
||||||
|
def train(args):
|
||||||
|
if args.output_name is None:
|
||||||
|
args.output_name = args.token_string
|
||||||
|
use_template = args.use_object_template or args.use_style_template
|
||||||
|
|
||||||
|
train_util.verify_training_args(args)
|
||||||
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
|
||||||
|
cache_latents = args.cache_latents
|
||||||
|
use_dreambooth_method = args.in_json is None
|
||||||
|
|
||||||
|
if args.seed is not None:
|
||||||
|
set_seed(args.seed)
|
||||||
|
|
||||||
|
tokenizer = train_util.load_tokenizer(args)
|
||||||
|
|
||||||
|
# acceleratorを準備する
|
||||||
|
print("prepare accelerator")
|
||||||
|
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||||
|
|
||||||
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
|
||||||
|
# モデルを読み込む
|
||||||
|
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)
|
||||||
|
|
||||||
|
# Convert the init_word to token_id
|
||||||
|
if args.init_word is not None:
|
||||||
|
init_token_id = tokenizer.encode(args.init_word, add_special_tokens=False)
|
||||||
|
assert len(
|
||||||
|
init_token_id) == 1, f"init word {args.init_word} is not converted to single token / 初期化単語が二つ以上のトークンに変換されます。別の単語を使ってください"
|
||||||
|
init_token_id = init_token_id[0]
|
||||||
|
else:
|
||||||
|
init_token_id = None
|
||||||
|
|
||||||
|
# add new word to tokenizer, count is num_vectors_per_token
|
||||||
|
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||||
|
num_added_tokens = tokenizer.add_tokens(token_strings)
|
||||||
|
assert num_added_tokens == args.num_vectors_per_token, f"tokenizer has same word to token string. please use another one / 指定したargs.token_stringは既に存在します。別の単語を使ってください: {args.token_string}"
|
||||||
|
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(token_strings)
|
||||||
|
print(f"tokens are added: {token_ids}")
|
||||||
|
assert min(token_ids) == token_ids[0] and token_ids[-1] == token_ids[0] + len(token_ids) - 1, f"token ids is not ordered"
|
||||||
|
assert len(tokenizer) - 1 == token_ids[-1], f"token ids is not end of tokenize: {len(tokenizer)}"
|
||||||
|
|
||||||
|
# Resize the token embeddings as we are adding new special tokens to the tokenizer
|
||||||
|
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||||
|
|
||||||
|
# Initialise the newly added placeholder token with the embeddings of the initializer token
|
||||||
|
token_embeds = text_encoder.get_input_embeddings().weight.data
|
||||||
|
if init_token_id is not None:
|
||||||
|
for token_id in token_ids:
|
||||||
|
token_embeds[token_id] = token_embeds[init_token_id]
|
||||||
|
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||||
|
|
||||||
|
# load weights
|
||||||
|
if args.weights is not None:
|
||||||
|
embeddings = load_weights(args.weights)
|
||||||
|
assert len(token_ids) == len(
|
||||||
|
embeddings), f"num_vectors_per_token is mismatch for weights / 指定した重みとnum_vectors_per_tokenの値が異なります: {len(embeddings)}"
|
||||||
|
# print(token_ids, embeddings.size())
|
||||||
|
for token_id, embedding in zip(token_ids, embeddings):
|
||||||
|
token_embeds[token_id] = embedding
|
||||||
|
# print(token_id, token_embeds[token_id].mean(), token_embeds[token_id].min())
|
||||||
|
print(f"weighs loaded")
|
||||||
|
|
||||||
|
print(f"create embeddings for {args.num_vectors_per_token} tokens, for {args.token_string}")
|
||||||
|
|
||||||
|
# データセットを準備する
|
||||||
|
if use_dreambooth_method:
|
||||||
|
print("Use DreamBooth method.")
|
||||||
|
train_dataset = DreamBoothDataset(args.train_batch_size, args.train_data_dir, args.reg_data_dir,
|
||||||
|
tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
|
||||||
|
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso, args.prior_loss_weight,
|
||||||
|
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop, args.debug_dataset)
|
||||||
|
else:
|
||||||
|
print("Train with captions.")
|
||||||
|
train_dataset = FineTuningDataset(args.in_json, args.train_batch_size, args.train_data_dir,
|
||||||
|
tokenizer, args.max_token_length, args.shuffle_caption, args.keep_tokens,
|
||||||
|
args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
|
||||||
|
args.flip_aug, args.color_aug, args.face_crop_aug_range, args.random_crop,
|
||||||
|
args.dataset_repeats, args.debug_dataset)
|
||||||
|
|
||||||
|
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||||
|
if use_template:
|
||||||
|
print("use template for training captions. is object: {args.use_object_template}")
|
||||||
|
templates = imagenet_templates_small if args.use_object_template else imagenet_style_templates_small
|
||||||
|
replace_to = " ".join(token_strings)
|
||||||
|
captions = []
|
||||||
|
for tmpl in templates:
|
||||||
|
captions.append(tmpl.format(replace_to))
|
||||||
|
train_dataset.add_replacement("", captions)
|
||||||
|
elif args.num_vectors_per_token > 1:
|
||||||
|
replace_to = " ".join(token_strings)
|
||||||
|
train_dataset.add_replacement(args.token_string, replace_to)
|
||||||
|
|
||||||
|
train_dataset.make_buckets()
|
||||||
|
|
||||||
|
if args.debug_dataset:
|
||||||
|
train_util.debug_dataset(train_dataset, show_input_ids=True)
|
||||||
|
return
|
||||||
|
if len(train_dataset) == 0:
|
||||||
|
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
||||||
|
return
|
||||||
|
|
||||||
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
|
|
||||||
|
# 学習を準備する
|
||||||
|
if cache_latents:
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
train_dataset.cache_latents(vae)
|
||||||
|
vae.to("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if args.gradient_checkpointing:
|
||||||
|
unet.enable_gradient_checkpointing()
|
||||||
|
text_encoder.gradient_checkpointing_enable()
|
||||||
|
|
||||||
|
# 学習に必要なクラスを準備する
|
||||||
|
print("prepare optimizer, data loader etc.")
|
||||||
|
|
||||||
|
# 8-bit Adamを使う
|
||||||
|
if args.use_8bit_adam:
|
||||||
|
try:
|
||||||
|
import bitsandbytes as bnb
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
|
||||||
|
print("use 8-bit Adam optimizer")
|
||||||
|
optimizer_class = bnb.optim.AdamW8bit
|
||||||
|
else:
|
||||||
|
optimizer_class = torch.optim.AdamW
|
||||||
|
|
||||||
|
trainable_params = text_encoder.get_input_embeddings().parameters()
|
||||||
|
|
||||||
|
# betaやweight decayはdiffusers DreamBoothもDreamBooth SDもデフォルト値のようなのでオプションはとりあえず省略
|
||||||
|
optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
|
||||||
|
|
||||||
|
# dataloaderを準備する
|
||||||
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
|
train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
|
||||||
|
|
||||||
|
# 学習ステップ数を計算する
|
||||||
|
if args.max_train_epochs is not None:
|
||||||
|
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
||||||
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
# lr schedulerを用意する
|
||||||
|
lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
|
||||||
|
|
||||||
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
|
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
index_no_updates = torch.arange(len(tokenizer)) < token_ids[0]
|
||||||
|
print(len(index_no_updates), torch.sum(index_no_updates))
|
||||||
|
orig_embeds_params = unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||||
|
|
||||||
|
# Freeze all parameters except for the token embeddings in text encoder
|
||||||
|
text_encoder.requires_grad_(True)
|
||||||
|
text_encoder.text_model.encoder.requires_grad_(False)
|
||||||
|
text_encoder.text_model.final_layer_norm.requires_grad_(False)
|
||||||
|
text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
|
||||||
|
# text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
|
||||||
|
|
||||||
|
unet.requires_grad_(False)
|
||||||
|
unet.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
if args.gradient_checkpointing: # according to TI example in Diffusers, train is required
|
||||||
|
unet.train()
|
||||||
|
else:
|
||||||
|
unet.eval()
|
||||||
|
|
||||||
|
if not cache_latents:
|
||||||
|
vae.requires_grad_(False)
|
||||||
|
vae.eval()
|
||||||
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
|
if args.full_fp16:
|
||||||
|
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||||
|
text_encoder.to(weight_dtype)
|
||||||
|
|
||||||
|
# resumeする
|
||||||
|
if args.resume is not None:
|
||||||
|
print(f"resume training from state: {args.resume}")
|
||||||
|
accelerator.load_state(args.resume)
|
||||||
|
|
||||||
|
# epoch数を計算する
|
||||||
|
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||||
|
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||||
|
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
||||||
|
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
||||||
|
|
||||||
|
# 学習する
|
||||||
|
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||||
|
print("running training / 学習開始")
|
||||||
|
print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
|
||||||
|
print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
|
||||||
|
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
|
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
|
print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
|
||||||
|
print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||||
|
print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
||||||
|
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||||
|
global_step = 0
|
||||||
|
|
||||||
|
noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
||||||
|
num_train_timesteps=1000, clip_sample=False)
|
||||||
|
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
accelerator.init_trackers("textual_inversion")
|
||||||
|
|
||||||
|
for epoch in range(num_train_epochs):
|
||||||
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
|
|
||||||
|
text_encoder.train()
|
||||||
|
|
||||||
|
loss_total = 0
|
||||||
|
bef_epo_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
with accelerator.accumulate(text_encoder):
|
||||||
|
with torch.no_grad():
|
||||||
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
else:
|
||||||
|
# latentに変換
|
||||||
|
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||||
|
latents = latents * 0.18215
|
||||||
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
|
# Get the text embedding for conditioning
|
||||||
|
input_ids = batch["input_ids"].to(accelerator.device)
|
||||||
|
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, torch.float) # weight_dtype) use float instead of fp16/bf16 because text encoder is float
|
||||||
|
|
||||||
|
# Sample noise that we'll add to the latents
|
||||||
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
|
|
||||||
|
# Sample a random timestep for each image
|
||||||
|
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||||
|
timesteps = timesteps.long()
|
||||||
|
|
||||||
|
# Add noise to the latents according to the noise magnitude at each timestep
|
||||||
|
# (this is the forward diffusion process)
|
||||||
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
|
if args.v_parameterization:
|
||||||
|
# v-parameterization training
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
target = noise
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
|
loss = loss * loss_weights
|
||||||
|
|
||||||
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
accelerator.backward(loss)
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
params_to_clip = text_encoder.get_input_embeddings().parameters()
|
||||||
|
accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
|
with torch.no_grad():
|
||||||
|
unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates]
|
||||||
|
|
||||||
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
|
if accelerator.sync_gradients:
|
||||||
|
progress_bar.update(1)
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
current_loss = loss.detach().item()
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
|
loss_total += current_loss
|
||||||
|
avr_loss = loss_total / (step+1)
|
||||||
|
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||||
|
progress_bar.set_postfix(**logs)
|
||||||
|
|
||||||
|
if global_step >= args.max_train_steps:
|
||||||
|
break
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
||||||
|
accelerator.log(logs, step=epoch+1)
|
||||||
|
|
||||||
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
updated_embs = unwrap_model(text_encoder).get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||||
|
d = updated_embs - bef_epo_embs
|
||||||
|
print(bef_epo_embs.size(), updated_embs.size(), d.mean(), d.min())
|
||||||
|
|
||||||
|
if args.save_every_n_epochs is not None:
|
||||||
|
model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name
|
||||||
|
|
||||||
|
def save_func():
|
||||||
|
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
||||||
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
print(f"saving checkpoint: {ckpt_file}")
|
||||||
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
|
||||||
|
def remove_old_func(old_epoch_no):
|
||||||
|
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||||
|
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||||
|
if os.path.exists(old_ckpt_file):
|
||||||
|
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||||
|
os.remove(old_ckpt_file)
|
||||||
|
|
||||||
|
saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||||
|
if saving and args.save_state:
|
||||||
|
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||||
|
|
||||||
|
# end of epoch
|
||||||
|
|
||||||
|
is_main_process = accelerator.is_main_process
|
||||||
|
if is_main_process:
|
||||||
|
text_encoder = unwrap_model(text_encoder)
|
||||||
|
|
||||||
|
accelerator.end_training()
|
||||||
|
|
||||||
|
if args.save_state:
|
||||||
|
train_util.save_state_on_train_end(args, accelerator)
|
||||||
|
|
||||||
|
updated_embs = text_encoder.get_input_embeddings().weight[token_ids].data.detach().clone()
|
||||||
|
|
||||||
|
del accelerator # この後メモリを使うのでこれは消す
|
||||||
|
|
||||||
|
if is_main_process:
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||||
|
ckpt_name = model_name + '.' + args.save_model_as
|
||||||
|
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||||
|
|
||||||
|
print(f"save trained model to {ckpt_file}")
|
||||||
|
save_weights(ckpt_file, updated_embs, save_dtype)
|
||||||
|
print("model saved.")
|
||||||
|
|
||||||
|
|
||||||
|
def save_weights(file, updated_embs, save_dtype):
|
||||||
|
state_dict = {"emb_params": updated_embs}
|
||||||
|
|
||||||
|
if save_dtype is not None:
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
v = state_dict[key]
|
||||||
|
v = v.detach().clone().to("cpu").to(save_dtype)
|
||||||
|
state_dict[key] = v
|
||||||
|
|
||||||
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
save_file(state_dict, file)
|
||||||
|
else:
|
||||||
|
torch.save(state_dict, file) # can be loaded in Web UI
|
||||||
|
|
||||||
|
|
||||||
|
def load_weights(file):
|
||||||
|
if os.path.splitext(file)[1] == '.safetensors':
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
data = load_file(file)
|
||||||
|
else:
|
||||||
|
# compatible to Web UI's file format
|
||||||
|
data = torch.load(file, map_location='cpu')
|
||||||
|
if type(data) != dict:
|
||||||
|
raise ValueError(f"weight file is not dict / 重みファイルがdict形式ではありません: {file}")
|
||||||
|
|
||||||
|
if 'string_to_param' in data: # textual inversion embeddings
|
||||||
|
data = data['string_to_param']
|
||||||
|
if hasattr(data, '_parameters'): # support old PyTorch?
|
||||||
|
data = getattr(data, '_parameters')
|
||||||
|
|
||||||
|
emb = next(iter(data.values()))
|
||||||
|
if type(emb) != torch.Tensor:
|
||||||
|
raise ValueError(f"weight file does not contains Tensor / 重みファイルのデータがTensorではありません: {file}")
|
||||||
|
|
||||||
|
if len(emb.size()) == 1:
|
||||||
|
emb = emb.unsqueeze(0)
|
||||||
|
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
train_util.add_sd_models_arguments(parser)
|
||||||
|
train_util.add_dataset_arguments(parser, True, True)
|
||||||
|
train_util.add_training_arguments(parser, True)
|
||||||
|
|
||||||
|
parser.add_argument("--save_model_as", type=str, default="pt", choices=[None, "ckpt", "pt", "safetensors"],
|
||||||
|
help="format to save the model (default is .pt) / モデル保存時の形式(デフォルトはpt)")
|
||||||
|
|
||||||
|
parser.add_argument("--weights", type=str, default=None,
|
||||||
|
help="embedding weights to initialize / 学習するネットワークの初期重み")
|
||||||
|
parser.add_argument("--num_vectors_per_token", type=int, default=1,
|
||||||
|
help='number of vectors per token / トークンに割り当てるembeddingsの要素数')
|
||||||
|
parser.add_argument("--token_string", type=str, default=None,
|
||||||
|
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること")
|
||||||
|
parser.add_argument("--init_word", type=str, default=None,
|
||||||
|
help="word to initialize vector / ベクトルを初期化に使用する単語、tokenizerで一語になること")
|
||||||
|
parser.add_argument("--use_object_template", action='store_true',
|
||||||
|
help="ignore caption and use default templates for object / キャプションは使わずデフォルトの物体用テンプレートで学習する")
|
||||||
|
parser.add_argument("--use_style_template", action='store_true',
|
||||||
|
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
train(args)
|
||||||
63
train_ti_README-ja.md
Normal file
63
train_ti_README-ja.md
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
## Textual Inversionの学習について
|
||||||
|
|
||||||
|
[Textual Inversion](https://textual-inversion.github.io/)です。実装に当たっては https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion を大いに参考にしました。
|
||||||
|
|
||||||
|
学習したモデルはWeb UIでもそのまま使えます。
|
||||||
|
|
||||||
|
なお恐らくSD2.xにも対応していますが現時点では未テストです。
|
||||||
|
|
||||||
|
## 学習方法
|
||||||
|
|
||||||
|
``train_textual_inversion.py`` を用います。
|
||||||
|
|
||||||
|
データの準備については ``train_network.py`` と全く同じですので、[そちらのドキュメント](./train_network_README-ja.md)を参照してください。
|
||||||
|
|
||||||
|
## オプション
|
||||||
|
|
||||||
|
以下はコマンドラインの例です(DreamBooth手法)。
|
||||||
|
|
||||||
|
```
|
||||||
|
accelerate launch --num_cpu_threads_per_process 1 train_textual_inversion.py
|
||||||
|
--pretrained_model_name_or_path=..\models\model.ckpt
|
||||||
|
--train_data_dir=..\data\db\char1 --output_dir=..\ti_train1
|
||||||
|
--resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
|
||||||
|
--max_train_steps=400 --use_8bit_adam --xformers --mixed_precision=fp16
|
||||||
|
--save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
|
||||||
|
--token_string=mychar4 --init_word=cute --num_vectors_per_token=4
|
||||||
|
```
|
||||||
|
|
||||||
|
``--token_string`` に学習時のトークン文字列を指定します。__学習時のプロンプトは、この文字列を含むようにしてください(token_stringがmychar4なら、``mychar4 1girl`` など)__。プロンプトのこの文字列の部分が、Textual Inversionの新しいtokenに置換されて学習されます。
|
||||||
|
|
||||||
|
プロンプトにトークン文字列が含まれているかどうかは、``--debug_dataset`` で置換後のtoken idが表示されますので、以下のように ``49408`` 以降のtokenが存在するかどうかで確認できます。
|
||||||
|
|
||||||
|
```
|
||||||
|
input ids: tensor([[49406, 49408, 49409, 49410, 49411, 49412, 49413, 49414, 49415, 49407,
|
||||||
|
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||||
|
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||||
|
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||||
|
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||||
|
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||||
|
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
|
||||||
|
49407, 49407, 49407, 49407, 49407, 49407, 49407]])
|
||||||
|
```
|
||||||
|
|
||||||
|
tokenizerがすでに持っている単語(一般的な単語)は使用できません。
|
||||||
|
|
||||||
|
``--init_word`` にembeddingsを初期化するときのコピー元トークンの文字列を指定します。学ばせたい概念が近いものを選ぶとよいようです。二つ以上のトークンになる文字列は指定できません。
|
||||||
|
|
||||||
|
``--num_vectors_per_token`` にいくつのトークンをこの学習で使うかを指定します。多いほうが表現力が増しますが、その分多くのトークンを消費します。たとえばnum_vectors_per_token=8の場合、指定したトークン文字列は(一般的なプロンプトの77トークン制限のうち)8トークンを消費します。
|
||||||
|
|
||||||
|
|
||||||
|
その他、以下のオプションが指定できます。
|
||||||
|
|
||||||
|
* --weights
|
||||||
|
* 学習前に学習済みのembeddingsを読み込み、そこから追加で学習します。
|
||||||
|
* --use_object_template
|
||||||
|
* キャプションではなく既定の物体用テンプレート文字列(``a photo of a {}``など)で学習します。公式実装と同じになります。キャプションは無視されます。
|
||||||
|
* --use_style_template
|
||||||
|
* キャプションではなく既定のスタイル用テンプレート文字列で学習します(``a painting in the style of {}``など)。公式実装と同じになります。キャプションは無視されます。
|
||||||
|
|
||||||
|
## 当リポジトリ内の画像生成スクリプトで生成する
|
||||||
|
|
||||||
|
gen_img_diffusers.pyに、``--textual_inversion_embeddings`` オプションで学習したembeddingsファイルを指定してください(複数可)。プロンプトでembeddingsファイルのファイル名(拡張子を除く)を使うと、そのembeddingsが適用されます。
|
||||||
|
|
||||||
Reference in New Issue
Block a user