mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Merge branch 'dev' into dev
This commit is contained in:
@@ -16,9 +16,10 @@ GUIやPowerShellスクリプトなど、より使いやすくする機能が[bma
|
||||
|
||||
当リポジトリ内およびnote.comに記事がありますのでそちらをご覧ください(将来的にはすべてこちらへ移すかもしれません)。
|
||||
|
||||
* [学習について、共通編](./train_README-ja.md) : データ整備やオプションなど
|
||||
* [データセット設定](./config_README-ja.md)
|
||||
* [DreamBoothの学習について](./train_db_README-ja.md)
|
||||
* [fine-tuningのガイド](./fine_tune_README_ja.md):
|
||||
BLIPによるキャプショニングと、DeepDanbooruまたはWD14 taggerによるタグ付けを含みます
|
||||
* [LoRAの学習について](./train_network_README-ja.md)
|
||||
* [Textual Inversionの学習について](./train_ti_README-ja.md)
|
||||
* note.com [画像生成スクリプト](https://note.com/kohya_ss/n/n2693183a798e)
|
||||
@@ -131,6 +132,8 @@ pip install --use-pep517 --upgrade -r requirements.txt
|
||||
|
||||
LoRAの実装は[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を基にしたものです。感謝申し上げます。
|
||||
|
||||
Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
|
||||
|
||||
## ライセンス
|
||||
|
||||
スクリプトのライセンスはASL 2.0ですが(Diffusersおよびcloneofsimo氏のリポジトリ由来のものも同様)、一部他のライセンスのコードを含みます。
|
||||
|
||||
109
README.md
109
README.md
@@ -26,11 +26,12 @@ The scripts are tested with PyTorch 1.12.1 and 1.13.0, Diffusers 0.10.2.
|
||||
|
||||
## Links to how-to-use documents
|
||||
|
||||
All documents are in Japanese currently, and CUI based.
|
||||
All documents are in Japanese currently.
|
||||
|
||||
* [Training guide - common](./train_README-ja.md) : data preparation, options etc...
|
||||
* [Dataset config](./config_README-ja.md)
|
||||
* [DreamBooth training guide](./train_db_README-ja.md)
|
||||
* [Step by Step fine-tuning guide](./fine_tune_README_ja.md):
|
||||
Including BLIP captioning and tagging by DeepDanbooru or WD14 tagger
|
||||
* [training LoRA](./train_network_README-ja.md)
|
||||
* [training Textual Inversion](./train_ti_README-ja.md)
|
||||
* note.com [Image generation](https://note.com/kohya_ss/n/n2693183a798e)
|
||||
@@ -110,11 +111,13 @@ Once the commands have completed successfully you should be ready to use the new
|
||||
|
||||
## Credits
|
||||
|
||||
The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work!!!
|
||||
The implementation for LoRA is based on [cloneofsimo's repo](https://github.com/cloneofsimo/lora). Thank you for great work!
|
||||
|
||||
The LoRA expansion to Conv2d 3x3 was initially released by cloneofsimo and its effectiveness was demonstrated at [LoCon](https://github.com/KohakuBlueleaf/LoCon) by KohakuBlueleaf. Thank you so much KohakuBlueleaf!
|
||||
|
||||
## License
|
||||
|
||||
The majority of scripts is licensed under ASL 2.0 (including codes from Diffusers, cloneofsimo's), however portions of the project are available under separate license terms:
|
||||
The majority of scripts is licensed under ASL 2.0 (including codes from Diffusers, cloneofsimo's and LoCon), however portions of the project are available under separate license terms:
|
||||
|
||||
[Memory Efficient Attention Pytorch](https://github.com/lucidrains/memory-efficient-attention-pytorch): MIT
|
||||
|
||||
@@ -124,38 +127,78 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
- 2 Mar. 2023, 2023/3/2:
|
||||
- 9 Mar. 2023, 2023/3/9:
|
||||
- There may be problems due to major changes. If you cannot revert back to the previous version when problems occur, please do not update for a while.
|
||||
- Dependencies are updated, Please [upgrade](#upgrade) the repo.
|
||||
- Add detail dataset config feature by extra config file. Thanks to fur0ut0 for this great contribution!
|
||||
- Documentation is [here](./config_README-ja.md) (only in Japanese currently.)
|
||||
- Specify ``.toml`` file with ``--dataset_config`` option.
|
||||
- The previous options for dataset can be used as is.
|
||||
- There might be a bug due to the large scale of update, please report any problems if you find.
|
||||
- Add feature to generate sample images in the middle of training for each training scripts.
|
||||
- ``--sample_every_n_steps`` and ``--sample_every_n_epochs`` options: frequency to generate.
|
||||
- ``--sample_prompts`` option: the file contains prompts (each line generates one image.)
|
||||
- The prompt is subset of ``gen_img_diffusers.py``. The prompt options ``w, h, d, l, s, n`` are supported.
|
||||
- ``--sample_sampler`` option: sampler (scheduler) for generating, such as ddim or k_euler. See help for useable samplers.
|
||||
- Add ``--tokenizer_cache_dir`` to each training and generation scripts to cache Tokenizer locally from Diffusers.
|
||||
- Scripts will support offline training/generation after caching.
|
||||
- Support letents upscaling for highres. fix, and VAE batch size in ``gen_img_diffusers.py`` (no documentation yet.)
|
||||
- Minimum metadata (module name, dim, alpha and network_args) is recorded even with `--no_metadata`, issue https://github.com/kohya-ss/sd-scripts/issues/254
|
||||
- `train_network.py` supports LoRA for Conv2d-3x3 (extended to conv2d with a kernel size not 1x1).
|
||||
- Same as a current version of [LoCon](https://github.com/KohakuBlueleaf/LoCon). __Thank you very much KohakuBlueleaf for your help!__
|
||||
- LoCon will be enhanced in the future. Compatibility for future versions is not guaranteed.
|
||||
- Specify `--network_args` option like: `--network_args "conv_dim=4" "conv_alpha=1"`
|
||||
- [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) version 0.5.0 or later is required to use 'LoRA for Conv2d-3x3' in Stable Diffusion web UI.
|
||||
- __Stable Diffusion web UI built-in LoRA does not support 'LoRA for Conv2d-3x3' now. Consider carefully whether or not to use it.__
|
||||
- Merging/extracting scripts also support LoRA for Conv2d-3x3.
|
||||
- Free CUDA memory after sample generation to reduce VRAM usage, issue https://github.com/kohya-ss/sd-scripts/issues/260
|
||||
- Empty caption doesn't cause error now, issue https://github.com/kohya-ss/sd-scripts/issues/258
|
||||
- Fix sample generation is crashing in Textual Inversion training when using templates, or if height/width is not divisible by 8.
|
||||
- Update documents (Japanese only).
|
||||
|
||||
- 大きく変更したため不具合があるかもしれません。問題が起きた時にスクリプトを前のバージョンに戻せない場合は、しばらく更新を控えてください。
|
||||
- ライブラリを更新しました。[アップグレード](https://github.com/kohya-ss/sd-scripts/blob/main/README-ja.md#%E3%82%A2%E3%83%83%E3%83%97%E3%82%B0%E3%83%AC%E3%83%BC%E3%83%89)に従って更新してください。
|
||||
- 設定ファイルによるデータセット定義機能を追加しました。素晴らしいPRを提供していただいた fur0ut0 氏に感謝します。
|
||||
- ドキュメントは[こちら](./config_README-ja.md)。
|
||||
- ``--dataset_config`` オプションで ``.toml`` ファイルを指定してください。
|
||||
- 今までのオプションはそのまま使えます。
|
||||
- 大規模なアップデートのため、もし不具合がありましたらご報告ください。
|
||||
- 学習の途中でサンプル画像を生成する機能を各学習スクリプトに追加しました。
|
||||
- ``--sample_every_n_steps`` と ``--sample_every_n_epochs`` オプション:生成頻度を指定
|
||||
- ``--sample_prompts`` オプション:プロンプトを記述したファイルを指定(1行ごとに1枚の画像を生成)
|
||||
- プロンプトには ``gen_img_diffusers.py`` のプロンプトオプションの一部、 ``w, h, d, l, s, n`` が使えます。
|
||||
- ``--sample_sampler`` オプション:ddim や k_euler などの sampler (scheduler) を指定します。使用できる sampler についてはヘルプをご覧ください。
|
||||
- ``--tokenizer_cache_dir`` オプションを各学習スクリプトおよび生成スクリプトに追加しました。Diffusers から Tokenizer を取得してきてろーかるに保存します。
|
||||
- 一度キャッシュしておくことでオフライン学習、生成ができるかもしれません。
|
||||
- ``gen_img_diffusers.py`` で highres. fix での letents upscaling と VAE のバッチサイズ指定に対応しました。
|
||||
- 最低限のメタデータ(module name, dim, alpha および network_args)が `--no_metadata` オプション指定時にも記録されます。issue https://github.com/kohya-ss/sd-scripts/issues/254
|
||||
- `train_network.py` で LoRAの Conv2d-3x3 拡張に対応しました(カーネルサイズ1x1以外のConv2dにも対象範囲を拡大します)。
|
||||
- 現在のバージョンの [LoCon](https://github.com/KohakuBlueleaf/LoCon) と同一の仕様です。__KohakuBlueleaf氏のご支援に深く感謝します。__
|
||||
- LoCon が将来的に拡張された場合、それらのバージョンでの互換性は保証できません。
|
||||
- `--network_args` オプションを `--network_args "conv_dim=4" "conv_alpha=1"` のように指定してください。
|
||||
- Stable Diffusion web UI での使用には [Additional Networks extension](https://github.com/kohya-ss/sd-webui-additional-networks) のversion 0.5.0 以降が必要です。
|
||||
- __Stable Diffusion web UI の LoRA 機能は LoRAの Conv2d-3x3 拡張に対応していないようです。使用するか否か慎重にご検討ください。__
|
||||
- マージ、抽出のスクリプトについても LoRA の Conv2d-3x3 拡張に対応しました.
|
||||
- サンプル画像生成後にCUDAメモリを解放しVRAM使用量を削減しました。 issue https://github.com/kohya-ss/sd-scripts/issues/260
|
||||
- 空のキャプションが使えるようになりました。 issue https://github.com/kohya-ss/sd-scripts/issues/258
|
||||
- Textual Inversion 学習でテンプレートを使ったとき、height/width が 8 で割り切れなかったときにサンプル画像生成がクラッシュするのを修正しました。
|
||||
- ドキュメント類を更新しました。
|
||||
|
||||
- Sample image generation:
|
||||
A prompt file might look like this, for example
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are not working.
|
||||
|
||||
- サンプル画像生成:
|
||||
プロンプトファイルは例えば以下のようになります。
|
||||
|
||||
```
|
||||
# prompt 1
|
||||
masterpiece, best quality, 1girl, in white shirts, upper body, looking at viewer, simple background --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 768 --h 768 --d 1 --l 7.5 --s 28
|
||||
|
||||
# prompt 2
|
||||
masterpiece, best quality, 1boy, in business suit, standing at street, looking back --n low quality, worst quality, bad anatomy,bad composition, poor, low effort --w 576 --h 832 --d 2 --l 5.5 --s 40
|
||||
```
|
||||
|
||||
`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能できます。
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
`( )` や `[ ]` などの重みづけは動作しません。
|
||||
|
||||
Please read [Releases](https://github.com/kohya-ss/sd-scripts/releases) for recent updates.
|
||||
最近の更新情報は [Release](https://github.com/kohya-ss/sd-scripts/releases) をご覧ください。
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import math
|
||||
import os
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
|
||||
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
@@ -916,7 +916,11 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
|
||||
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
||||
else:
|
||||
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
|
||||
|
||||
logging.set_verbosity_error() # don't show annoying warning
|
||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
|
||||
print("loading text encoder:", info)
|
||||
|
||||
|
||||
@@ -924,7 +924,9 @@ class FineTuningDataset(BaseDataset):
|
||||
elif tags is not None and len(tags) > 0:
|
||||
caption = caption + ', ' + tags
|
||||
tags_list.append(tags)
|
||||
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
||||
image_info.image_size = img_md.get('train_resolution')
|
||||
@@ -2207,7 +2209,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
return
|
||||
|
||||
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
@@ -2353,6 +2355,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
height = max(64, height - height % 8) # round to divisible by 8
|
||||
width = max(64, width - width % 8) # round to divisible by 8
|
||||
print(f"prompt: {prompt}")
|
||||
print(f"negative_prompt: {negative_prompt}")
|
||||
print(f"height: {height}")
|
||||
|
||||
@@ -21,7 +21,7 @@ def main(file):
|
||||
|
||||
for key, value in values:
|
||||
value = value.to(torch.float32)
|
||||
print(f"{key},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
||||
print(f"{key},{str(tuple(value.size())).replace(', ', '-')},{torch.mean(torch.abs(value))},{torch.min(torch.abs(value))}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -45,8 +45,13 @@ def svd(args):
|
||||
text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_tuned)
|
||||
|
||||
# create LoRA network to extract weights: Use dim (rank) as alpha
|
||||
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o)
|
||||
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t)
|
||||
if args.conv_dim is None:
|
||||
kwargs = {}
|
||||
else:
|
||||
kwargs = {"conv_dim": args.conv_dim, "conv_alpha": args.conv_dim}
|
||||
|
||||
lora_network_o = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_o, unet_o, **kwargs)
|
||||
lora_network_t = lora.create_network(1.0, args.dim, args.dim, None, text_encoder_t, unet_t, **kwargs)
|
||||
assert len(lora_network_o.text_encoder_loras) == len(
|
||||
lora_network_t.text_encoder_loras), f"model version is different (SD1.x vs SD2.x) / それぞれのモデルのバージョンが違います(SD1.xベースとSD2.xベース) "
|
||||
|
||||
@@ -85,13 +90,27 @@ def svd(args):
|
||||
|
||||
# make LoRA with svd
|
||||
print("calculating by svd")
|
||||
rank = args.dim
|
||||
lora_weights = {}
|
||||
with torch.no_grad():
|
||||
for lora_name, mat in tqdm(list(diffs.items())):
|
||||
# if args.conv_dim is None, diffs do not include LoRAs for conv2d-3x3
|
||||
conv2d = (len(mat.size()) == 4)
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
|
||||
rank = args.dim if not conv2d_3x3 or args.conv_dim is None else args.conv_dim
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
|
||||
if args.device:
|
||||
mat = mat.to(args.device)
|
||||
# print(mat.size(), mat.device, rank, in_dim, out_dim)
|
||||
rank = min(rank, in_dim, out_dim) # LoRA rank cannot exceed the original dim
|
||||
|
||||
if conv2d:
|
||||
mat = mat.squeeze()
|
||||
if conv2d_3x3:
|
||||
mat = mat.flatten(start_dim=1)
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
@@ -108,6 +127,13 @@ def svd(args):
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, rank, 1, 1)
|
||||
Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
U = U.to("cpu").contiguous()
|
||||
Vh = Vh.to("cpu").contiguous()
|
||||
|
||||
lora_weights[lora_name] = (U, Vh)
|
||||
|
||||
# make state dict for LoRA
|
||||
@@ -124,8 +150,8 @@ def svd(args):
|
||||
|
||||
weights = lora_weights[lora_name][i]
|
||||
# print(key, i, weights.size(), lora_sd[key].size())
|
||||
if len(lora_sd[key].size()) == 4:
|
||||
weights = weights.unsqueeze(2).unsqueeze(3)
|
||||
# if len(lora_sd[key].size()) == 4:
|
||||
# weights = weights.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
assert weights.size() == lora_sd[key].size(), f"size unmatch: {key}"
|
||||
lora_sd[key] = weights
|
||||
@@ -139,7 +165,7 @@ def svd(args):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
|
||||
# minimum metadata
|
||||
metadata = {"ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||
metadata = {"ss_network_module": "networks.lora", "ss_network_dim": str(args.dim), "ss_network_alpha": str(args.dim)}
|
||||
|
||||
lora_network_o.save_weights(args.save_to, save_dtype, metadata)
|
||||
print(f"LoRA weights are saved to: {args.save_to}")
|
||||
@@ -158,6 +184,8 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--save_to", type=str, default=None,
|
||||
help="destination file name: ckpt or safetensors file / 保存先のファイル名、ckptまたはsafetensors")
|
||||
parser.add_argument("--dim", type=int, default=4, help="dimension (rank) of LoRA (default 4) / LoRAの次元数(rank)(デフォルト4)")
|
||||
parser.add_argument("--conv_dim", type=int, default=None,
|
||||
help="dimension (rank) of LoRA for Conv2d-3x3 (default None, disabled) / LoRAのConv2d-3x3の次元数(rank)(デフォルトNone、適用なし)")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -70,7 +70,7 @@ class LoRAModule(torch.nn.Module):
|
||||
if self.region is None:
|
||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||
|
||||
# reginal LoRA
|
||||
# regional LoRA FIXME same as additional-network extension
|
||||
if x.size()[1] % 77 == 0:
|
||||
# print(f"LoRA for context: {self.lora_name}")
|
||||
self.region = None
|
||||
@@ -107,10 +107,14 @@ def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, un
|
||||
network_dim = 4 # default
|
||||
|
||||
# extract dim/alpha for conv2d, and block dim
|
||||
conv_dim = int(kwargs.get('conv_dim', network_dim))
|
||||
conv_alpha = kwargs.get('conv_alpha', network_alpha)
|
||||
if conv_alpha is not None:
|
||||
conv_alpha = float(conv_alpha)
|
||||
conv_dim = kwargs.get('conv_dim', None)
|
||||
conv_alpha = kwargs.get('conv_alpha', None)
|
||||
if conv_dim is not None:
|
||||
conv_dim = int(conv_dim)
|
||||
if conv_alpha is None:
|
||||
conv_alpha = 1.0
|
||||
else:
|
||||
conv_alpha = float(conv_alpha)
|
||||
|
||||
"""
|
||||
block_dims = kwargs.get("block_dims")
|
||||
@@ -165,7 +169,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, **kwa
|
||||
elif 'lora_down' in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
print(lora_name, value.size(), dim)
|
||||
# print(lora_name, value.size(), dim)
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
|
||||
@@ -48,7 +48,7 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
for name, module in root_module.named_modules():
|
||||
if module.__class__.__name__ in target_replace_modules:
|
||||
for child_name, child_module in module.named_modules():
|
||||
if child_module.__class__.__name__ == "Linear" or (child_module.__class__.__name__ == "Conv2d" and child_module.kernel_size == (1, 1)):
|
||||
if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
|
||||
lora_name = prefix + '.' + name + '.' + child_name
|
||||
lora_name = lora_name.replace('.', '_')
|
||||
name_to_module[lora_name] = child_module
|
||||
@@ -80,13 +80,19 @@ def merge_to_sd_model(text_encoder, unet, models, ratios, merge_dtype):
|
||||
|
||||
# W <- W + U * D
|
||||
weight = module.weight
|
||||
# print(module_name, down_weight.size(), up_weight.size())
|
||||
if len(weight.size()) == 2:
|
||||
# linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
# conv2d
|
||||
elif down_weight.size()[2:4] == (1, 1):
|
||||
# conv2d 1x1
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
else:
|
||||
# conv2d 3x3
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
# print(conved.size(), weight.size(), module.stride, module.padding)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
module.weight = torch.nn.Parameter(weight)
|
||||
|
||||
@@ -123,7 +129,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
||||
alphas[lora_module_name] = alpha
|
||||
if lora_module_name not in base_alphas:
|
||||
base_alphas[lora_module_name] = alpha
|
||||
|
||||
|
||||
print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
|
||||
|
||||
# merge
|
||||
@@ -145,7 +151,7 @@ def merge_lora_models(models, ratios, merge_dtype):
|
||||
merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
|
||||
else:
|
||||
merged_sd[key] = lora_sd[key] * scale
|
||||
|
||||
|
||||
# set alpha to sd
|
||||
for lora_module_name, alpha in base_alphas.items():
|
||||
key = lora_module_name + ".alpha"
|
||||
|
||||
@@ -3,12 +3,13 @@
|
||||
# Thanks to cloneofsimo and kohya
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
from tqdm import tqdm
|
||||
from library import train_util, model_util
|
||||
import numpy as np
|
||||
|
||||
MIN_SV = 1e-6
|
||||
|
||||
def load_state_dict(file_name, dtype):
|
||||
if model_util.is_safetensors(file_name):
|
||||
@@ -38,12 +39,149 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
def index_sv_cumulative(S, target):
|
||||
original_sum = float(torch.sum(S))
|
||||
cumulative_sums = torch.cumsum(S, dim=0)/original_sum
|
||||
index = int(torch.searchsorted(cumulative_sums, target)) + 1
|
||||
if index >= len(S):
|
||||
index = len(S) - 1
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def index_sv_fro(S, target):
|
||||
S_squared = S.pow(2)
|
||||
s_fro_sq = float(torch.sum(S_squared))
|
||||
sum_S_squared = torch.cumsum(S_squared, dim=0)/s_fro_sq
|
||||
index = int(torch.searchsorted(sum_S_squared, target**2)) + 1
|
||||
if index >= len(S):
|
||||
index = len(S) - 1
|
||||
|
||||
return index
|
||||
|
||||
|
||||
# Modified from Kohaku-blueleaf's extract/merge functions
|
||||
def extract_conv(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size, kernel_size, _ = weight.size()
|
||||
U, S, Vh = torch.linalg.svd(weight.reshape(out_size, -1).to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size, kernel_size, kernel_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank, 1, 1).cpu()
|
||||
del U, S, Vh, weight
|
||||
return param_dict
|
||||
|
||||
|
||||
def extract_linear(weight, lora_rank, dynamic_method, dynamic_param, device, scale=1):
|
||||
out_size, in_size = weight.size()
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.to(device))
|
||||
|
||||
param_dict = rank_resize(S, lora_rank, dynamic_method, dynamic_param, scale)
|
||||
lora_rank = param_dict["new_rank"]
|
||||
|
||||
U = U[:, :lora_rank]
|
||||
S = S[:lora_rank]
|
||||
U = U @ torch.diag(S)
|
||||
Vh = Vh[:lora_rank, :]
|
||||
|
||||
param_dict["lora_down"] = Vh.reshape(lora_rank, in_size).cpu()
|
||||
param_dict["lora_up"] = U.reshape(out_size, lora_rank).cpu()
|
||||
del U, S, Vh, weight
|
||||
return param_dict
|
||||
|
||||
|
||||
def merge_conv(lora_down, lora_up, device):
|
||||
in_rank, in_size, kernel_size, k_ = lora_down.shape
|
||||
out_size, out_rank, _, _ = lora_up.shape
|
||||
assert in_rank == out_rank and kernel_size == k_, f"rank {in_rank} {out_rank} or kernel {kernel_size} {k_} mismatch"
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
merged = lora_up.reshape(out_size, -1) @ lora_down.reshape(in_rank, -1)
|
||||
weight = merged.reshape(out_size, in_size, kernel_size, kernel_size)
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
def merge_linear(lora_down, lora_up, device):
|
||||
in_rank, in_size = lora_down.shape
|
||||
out_size, out_rank = lora_up.shape
|
||||
assert in_rank == out_rank, f"rank {in_rank} {out_rank} mismatch"
|
||||
|
||||
lora_down = lora_down.to(device)
|
||||
lora_up = lora_up.to(device)
|
||||
|
||||
weight = lora_up @ lora_down
|
||||
del lora_up, lora_down
|
||||
return weight
|
||||
|
||||
|
||||
def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):
|
||||
param_dict = {}
|
||||
|
||||
if dynamic_method=="sv_ratio":
|
||||
# Calculate new dim and alpha based off ratio
|
||||
max_sv = S[0]
|
||||
min_sv = max_sv/dynamic_param
|
||||
new_rank = max(torch.sum(S > min_sv).item(),1)
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method=="sv_cumulative":
|
||||
# Calculate new dim and alpha based off cumulative sum
|
||||
new_rank = index_sv_cumulative(S, dynamic_param)
|
||||
new_rank = max(new_rank, 1)
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
elif dynamic_method=="sv_fro":
|
||||
# Calculate new dim and alpha based off sqrt sum of squares
|
||||
new_rank = index_sv_fro(S, dynamic_param)
|
||||
new_rank = min(max(new_rank, 1), len(S)-1)
|
||||
new_alpha = float(scale*new_rank)
|
||||
else:
|
||||
new_rank = rank
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
|
||||
if S[0] <= MIN_SV: # Zero matrix, set dim to 1
|
||||
new_rank = 1
|
||||
new_alpha = float(scale*new_rank)
|
||||
elif new_rank > rank: # cap max rank at rank
|
||||
new_rank = rank
|
||||
new_alpha = float(scale*new_rank)
|
||||
|
||||
|
||||
# Calculate resize info
|
||||
s_sum = torch.sum(torch.abs(S))
|
||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||
|
||||
S_squared = S.pow(2)
|
||||
s_fro = torch.sqrt(torch.sum(S_squared))
|
||||
s_red_fro = torch.sqrt(torch.sum(S_squared[:new_rank]))
|
||||
fro_percent = float(s_red_fro/s_fro)
|
||||
|
||||
param_dict["new_rank"] = new_rank
|
||||
param_dict["new_alpha"] = new_alpha
|
||||
param_dict["sum_retained"] = (s_rank)/s_sum
|
||||
param_dict["fro_retained"] = fro_percent
|
||||
param_dict["max_ratio"] = S[0]/S[new_rank]
|
||||
|
||||
return param_dict
|
||||
|
||||
|
||||
def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
|
||||
network_alpha = None
|
||||
network_dim = None
|
||||
verbose_str = "\n"
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
fro_list = []
|
||||
|
||||
# Extract loaded lora dim and alpha
|
||||
for key, value in lora_sd.items():
|
||||
@@ -57,9 +195,9 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
network_alpha = network_dim
|
||||
|
||||
scale = network_alpha/network_dim
|
||||
new_alpha = float(scale*new_rank) # calculate new alpha from scale
|
||||
|
||||
print(f"old dimension: {network_dim}, old alpha: {network_alpha}, new alpha: {new_alpha}")
|
||||
if dynamic_method:
|
||||
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
|
||||
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
@@ -68,7 +206,6 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
|
||||
print("resizing lora...")
|
||||
with torch.no_grad():
|
||||
for key, value in tqdm(lora_sd.items()):
|
||||
if 'lora_down' in key:
|
||||
@@ -85,57 +222,43 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, verbose):
|
||||
conv2d = (len(lora_down_weight.size()) == 4)
|
||||
|
||||
if conv2d:
|
||||
lora_down_weight = lora_down_weight.squeeze()
|
||||
lora_up_weight = lora_up_weight.squeeze()
|
||||
|
||||
if device:
|
||||
org_device = lora_up_weight.device
|
||||
lora_up_weight = lora_up_weight.to(args.device)
|
||||
lora_down_weight = lora_down_weight.to(args.device)
|
||||
|
||||
full_weight_matrix = torch.matmul(lora_up_weight, lora_down_weight)
|
||||
|
||||
U, S, Vh = torch.linalg.svd(full_weight_matrix)
|
||||
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_conv(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
else:
|
||||
full_weight_matrix = merge_linear(lora_down_weight, lora_up_weight, device)
|
||||
param_dict = extract_linear(full_weight_matrix, new_rank, dynamic_method, dynamic_param, device, scale)
|
||||
|
||||
if verbose:
|
||||
s_sum = torch.sum(torch.abs(S))
|
||||
s_rank = torch.sum(torch.abs(S[:new_rank]))
|
||||
verbose_str+=f"{block_down_name:76} | "
|
||||
verbose_str+=f"sum(S) retained: {(s_rank)/s_sum:.1%}, max(S) ratio: {S[0]/S[new_rank]:0.1f}\n"
|
||||
max_ratio = param_dict['max_ratio']
|
||||
sum_retained = param_dict['sum_retained']
|
||||
fro_retained = param_dict['fro_retained']
|
||||
if not np.isnan(fro_retained):
|
||||
fro_list.append(float(fro_retained))
|
||||
|
||||
U = U[:, :new_rank]
|
||||
S = S[:new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
verbose_str+=f"{block_down_name:75} | "
|
||||
verbose_str+=f"sum(S) retained: {sum_retained:.1%}, fro retained: {fro_retained:.1%}, max(S) ratio: {max_ratio:0.1f}"
|
||||
|
||||
Vh = Vh[:new_rank, :]
|
||||
if verbose and dynamic_method:
|
||||
verbose_str+=f", dynamic | dim: {param_dict['new_rank']}, alpha: {param_dict['new_alpha']}\n"
|
||||
else:
|
||||
verbose_str+=f"\n"
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
low_val = -hi_val
|
||||
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.unsqueeze(2).unsqueeze(3)
|
||||
Vh = Vh.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
if device:
|
||||
U = U.to(org_device)
|
||||
Vh = Vh.to(org_device)
|
||||
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = Vh.to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = U.to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(new_alpha).to(save_dtype)
|
||||
new_alpha = param_dict['new_alpha']
|
||||
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
|
||||
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict['new_alpha']).to(save_dtype)
|
||||
|
||||
block_down_name = None
|
||||
block_up_name = None
|
||||
lora_down_weight = None
|
||||
lora_up_weight = None
|
||||
weights_loaded = False
|
||||
del param_dict
|
||||
|
||||
if verbose:
|
||||
print(verbose_str)
|
||||
|
||||
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
|
||||
print("resizing complete")
|
||||
return o_lora_sd, network_dim, new_alpha
|
||||
|
||||
@@ -151,6 +274,9 @@ def resize(args):
|
||||
return torch.bfloat16
|
||||
return None
|
||||
|
||||
if args.dynamic_method and not args.dynamic_param:
|
||||
raise Exception("If using dynamic_method, then dynamic_param is required")
|
||||
|
||||
merge_dtype = str_to_dtype('float') # matmul method above only seems to work in float32
|
||||
save_dtype = str_to_dtype(args.save_precision)
|
||||
if save_dtype is None:
|
||||
@@ -159,17 +285,23 @@ def resize(args):
|
||||
print("loading Model...")
|
||||
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
|
||||
|
||||
print("resizing rank...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.verbose)
|
||||
print("Resizing Lora...")
|
||||
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
|
||||
|
||||
# update metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
comment = metadata.get("ss_training_comment", "")
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
|
||||
if not args.dynamic_method:
|
||||
metadata["ss_training_comment"] = f"dimension is resized from {old_dim} to {args.new_rank}; {comment}"
|
||||
metadata["ss_network_dim"] = str(args.new_rank)
|
||||
metadata["ss_network_alpha"] = str(new_alpha)
|
||||
else:
|
||||
metadata["ss_training_comment"] = f"Dynamic resize with {args.dynamic_method}: {args.dynamic_param} from {old_dim}; {comment}"
|
||||
metadata["ss_network_dim"] = 'Dynamic'
|
||||
metadata["ss_network_alpha"] = 'Dynamic'
|
||||
|
||||
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||
metadata["sshs_model_hash"] = model_hash
|
||||
@@ -193,6 +325,11 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
parser.add_argument("--verbose", action="store_true",
|
||||
help="Display verbose resizing information / rank変更時の詳細情報を出力する")
|
||||
parser.add_argument("--dynamic_method", type=str, default=None, choices=[None, "sv_ratio", "sv_fro", "sv_cumulative"],
|
||||
help="Specify dynamic resizing method, --new_rank is used as a hard limit for max rank")
|
||||
parser.add_argument("--dynamic_param", type=float, default=None,
|
||||
help="Specify target for dynamic reduction")
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
resize(args)
|
||||
|
||||
@@ -35,7 +35,8 @@ def save_to_file(file_name, model, state_dict, dtype):
|
||||
torch.save(model, file_name)
|
||||
|
||||
|
||||
def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
||||
def merge_lora_models(models, ratios, new_rank, new_conv_rank, device, merge_dtype):
|
||||
print(f"new rank: {new_rank}, new conv rank: {new_conv_rank}")
|
||||
merged_sd = {}
|
||||
for model, ratio in zip(models, ratios):
|
||||
print(f"loading: {model}")
|
||||
@@ -58,11 +59,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
||||
in_dim = down_weight.size()[1]
|
||||
out_dim = up_weight.size()[0]
|
||||
conv2d = len(down_weight.size()) == 4
|
||||
print(lora_module_name, network_dim, alpha, in_dim, out_dim)
|
||||
kernel_size = None if not conv2d else down_weight.size()[2:4]
|
||||
# print(lora_module_name, network_dim, alpha, in_dim, out_dim, kernel_size)
|
||||
|
||||
# make original weight if not exist
|
||||
if lora_module_name not in merged_sd:
|
||||
weight = torch.zeros((out_dim, in_dim, 1, 1) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||
weight = torch.zeros((out_dim, in_dim, *kernel_size) if conv2d else (out_dim, in_dim), dtype=merge_dtype)
|
||||
if device:
|
||||
weight = weight.to(device)
|
||||
else:
|
||||
@@ -77,9 +79,12 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
||||
scale = (alpha / network_dim)
|
||||
if not conv2d: # linear
|
||||
weight = weight + ratio * (up_weight @ down_weight) * scale
|
||||
else:
|
||||
elif kernel_size == (1, 1):
|
||||
weight = weight + ratio * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)
|
||||
).unsqueeze(2).unsqueeze(3) * scale
|
||||
else:
|
||||
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
|
||||
weight = weight + ratio * conved * scale
|
||||
|
||||
merged_sd[lora_module_name] = weight
|
||||
|
||||
@@ -89,16 +94,25 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
||||
with torch.no_grad():
|
||||
for lora_module_name, mat in tqdm(list(merged_sd.items())):
|
||||
conv2d = (len(mat.size()) == 4)
|
||||
kernel_size = None if not conv2d else mat.size()[2:4]
|
||||
conv2d_3x3 = conv2d and kernel_size != (1, 1)
|
||||
out_dim, in_dim = mat.size()[0:2]
|
||||
|
||||
if conv2d:
|
||||
mat = mat.squeeze()
|
||||
if conv2d_3x3:
|
||||
mat = mat.flatten(start_dim=1)
|
||||
else:
|
||||
mat = mat.squeeze()
|
||||
|
||||
module_new_rank = new_conv_rank if conv2d_3x3 else new_rank
|
||||
|
||||
U, S, Vh = torch.linalg.svd(mat)
|
||||
|
||||
U = U[:, :new_rank]
|
||||
S = S[:new_rank]
|
||||
U = U[:, :module_new_rank]
|
||||
S = S[:module_new_rank]
|
||||
U = U @ torch.diag(S)
|
||||
|
||||
Vh = Vh[:new_rank, :]
|
||||
Vh = Vh[:module_new_rank, :]
|
||||
|
||||
dist = torch.cat([U.flatten(), Vh.flatten()])
|
||||
hi_val = torch.quantile(dist, CLAMP_QUANTILE)
|
||||
@@ -107,16 +121,16 @@ def merge_lora_models(models, ratios, new_rank, device, merge_dtype):
|
||||
U = U.clamp(low_val, hi_val)
|
||||
Vh = Vh.clamp(low_val, hi_val)
|
||||
|
||||
if conv2d:
|
||||
U = U.reshape(out_dim, module_new_rank, 1, 1)
|
||||
Vh = Vh.reshape(module_new_rank, in_dim, kernel_size[0], kernel_size[1])
|
||||
|
||||
up_weight = U
|
||||
down_weight = Vh
|
||||
|
||||
if conv2d:
|
||||
up_weight = up_weight.unsqueeze(2).unsqueeze(3)
|
||||
down_weight = down_weight.unsqueeze(2).unsqueeze(3)
|
||||
|
||||
merged_lora_sd[lora_module_name + '.lora_up.weight'] = up_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + '.lora_down.weight'] = down_weight.to("cpu").contiguous()
|
||||
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(new_rank)
|
||||
merged_lora_sd[lora_module_name + '.alpha'] = torch.tensor(module_new_rank)
|
||||
|
||||
return merged_lora_sd
|
||||
|
||||
@@ -138,7 +152,8 @@ def merge(args):
|
||||
if save_dtype is None:
|
||||
save_dtype = merge_dtype
|
||||
|
||||
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, args.device, merge_dtype)
|
||||
new_conv_rank = args.new_conv_rank if args.new_conv_rank is not None else args.new_rank
|
||||
state_dict = merge_lora_models(args.models, args.ratios, args.new_rank, new_conv_rank, args.device, merge_dtype)
|
||||
|
||||
print(f"saving model to: {args.save_to}")
|
||||
save_to_file(args.save_to, state_dict, state_dict, save_dtype)
|
||||
@@ -158,6 +173,8 @@ if __name__ == '__main__':
|
||||
help="ratios for each model / それぞれのLoRAモデルの比率")
|
||||
parser.add_argument("--new_rank", type=int, default=4,
|
||||
help="Specify rank of output LoRA / 出力するLoRAのrank (dim)")
|
||||
parser.add_argument("--new_conv_rank", type=int, default=None,
|
||||
help="Specify rank of output LoRA for Conv2d 3x3, None for same as new_rank / 出力するConv2D 3x3 LoRAのrank (dim)、Noneでnew_rankと同じ")
|
||||
parser.add_argument("--device", type=str, default=None, help="device to use, cuda for GPU / 計算を行うデバイス、cuda でGPUを使う")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -74,7 +74,7 @@ accelerate launch --num_cpu_threads_per_process 1 train_db.py
|
||||
|
||||
### よく使われるオプションについて
|
||||
|
||||
以下の場合にはオプションに関するドキュメントを参照してください。
|
||||
以下の場合には [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」を参照してください。
|
||||
|
||||
- Stable Diffusion 2.xまたはそこからの派生モデルを学習する
|
||||
- clip skipを2以上を前提としたモデルを学習する
|
||||
|
||||
@@ -431,10 +431,13 @@ def train(args):
|
||||
"ss_bucket_info": json.dumps(dataset.bucket_info),
|
||||
})
|
||||
|
||||
# add extra args
|
||||
if args.network_args:
|
||||
for key, value in net_kwargs.items():
|
||||
metadata["ss_arg_" + key] = value
|
||||
metadata["ss_network_args"] = json.dumps(net_kwargs)
|
||||
# for key, value in net_kwargs.items():
|
||||
# metadata["ss_arg_" + key] = value
|
||||
|
||||
# model name and hash
|
||||
if args.pretrained_model_name_or_path is not None:
|
||||
sd_model_name = args.pretrained_model_name_or_path
|
||||
if os.path.exists(sd_model_name):
|
||||
@@ -453,6 +456,13 @@ def train(args):
|
||||
|
||||
metadata = {k: str(v) for k, v in metadata.items()}
|
||||
|
||||
# make minimum metadata for filtering
|
||||
minimum_keys = ["ss_network_module", "ss_network_dim", "ss_network_alpha", "ss_network_args"]
|
||||
minimum_metadata = {}
|
||||
for key in minimum_keys:
|
||||
if key in metadata:
|
||||
minimum_metadata[key] = metadata[key]
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
@@ -569,7 +579,7 @@ def train(args):
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
metadata["ss_training_finished_at"] = str(time.time())
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
||||
unwrap_model(network).save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||
@@ -608,7 +618,7 @@ def train(args):
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model to {ckpt_file}")
|
||||
network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
|
||||
network.save_weights(ckpt_file, save_dtype, minimum_metadata if args.no_metadata else metadata)
|
||||
print("model saved.")
|
||||
|
||||
|
||||
|
||||
@@ -1,118 +1,99 @@
|
||||
## LoRAの学習について
|
||||
# LoRAの学習について
|
||||
|
||||
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)(arxiv)、[LoRA](https://github.com/microsoft/LoRA)(github)をStable Diffusionに適用したものです。
|
||||
|
||||
[cloneofsimo氏のリポジトリ](https://github.com/cloneofsimo/lora)を大いに参考にさせていただきました。ありがとうございます。
|
||||
|
||||
通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
||||
|
||||
Conv2d 3x3への拡大は [cloneofsimo氏](https://github.com/cloneofsimo/lora) が最初にリリースし、KohakuBlueleaf氏が [LoCon](https://github.com/KohakuBlueleaf/LoCon) でその有効性を明らかにしたものです。KohakuBlueleaf氏に深く感謝します。
|
||||
|
||||
8GB VRAMでもぎりぎり動作するようです。
|
||||
|
||||
[学習についての共通ドキュメント](./train_README-ja.md) もあわせてご覧ください。
|
||||
|
||||
## 学習したモデルに関する注意
|
||||
|
||||
cloneofsimo氏のリポジトリ、およびd8ahazard氏の[Dreambooth Extension for Stable-Diffusion-WebUI](https://github.com/d8ahazard/sd_dreambooth_extension)とは、現時点では互換性がありません。いくつかの機能拡張を行っているためです(後述)。
|
||||
|
||||
WebUI等で画像生成する場合には、学習したLoRAのモデルを学習元のStable Diffusionのモデルにこのリポジトリ内のスクリプトであらかじめマージしておくか、こちらの[WebUI用extension](https://github.com/kohya-ss/sd-webui-additional-networks)を使ってください。
|
||||
|
||||
## 学習方法
|
||||
# 学習の手順
|
||||
|
||||
train_network.pyを用います。
|
||||
あらかじめこのリポジトリのREADMEを参照し、環境整備を行ってください。
|
||||
|
||||
DreamBoothの手法(identifier(sksなど)とclass、オプションで正則化画像を用いる)と、キャプションを用いるfine tuningの手法の両方で学習できます。
|
||||
## データの準備
|
||||
|
||||
どちらの方法も既存のスクリプトとほぼ同じ方法で学習できます。異なる点については後述します。
|
||||
[学習データの準備について](./train_README-ja.md) を参照してください。
|
||||
|
||||
### DreamBoothの手法を用いる場合
|
||||
|
||||
[DreamBoothのガイド](./train_db_README-ja.md) を参照してデータを用意してください。
|
||||
## 学習の実行
|
||||
|
||||
学習するとき、train_db.pyの代わりにtrain_network.pyを指定してください。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。
|
||||
`train_network.py`を用います。
|
||||
|
||||
ほぼすべてのオプション(Stable Diffusionのモデル保存関係を除く)が使えますが、stop_text_encoder_trainingはサポートしていません。
|
||||
|
||||
### キャプションを用いる場合
|
||||
|
||||
[fine-tuningのガイド](./fine_tune_README_ja.md) を参照し、各手順を実行してください。
|
||||
|
||||
学習するとき、fine_tune.pyの代わりにtrain_network.pyを指定してください。ほぼすべてのオプション(モデル保存関係を除く)がそのまま使えます。そして「LoRAの学習のためのオプション」にあるようにLoRA関連のオプション(``network_dim``や``network_alpha``など)を追加してください。
|
||||
|
||||
なお「latentsの事前取得」は行わなくても動作します。VAEから学習時(またはキャッシュ時)にlatentを取得するため学習速度は遅くなりますが、代わりにcolor_augが使えるようになります。
|
||||
|
||||
### LoRAの学習のためのオプション
|
||||
|
||||
train_network.pyでは--network_moduleオプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
|
||||
`train_network.py`では `--network_module` オプションに、学習対象のモジュール名を指定します。LoRAに対応するのはnetwork.loraとなりますので、それを指定してください。
|
||||
|
||||
なお学習率は通常のDreamBoothやfine tuningよりも高めの、1e-4程度を指定するとよいようです。
|
||||
|
||||
以下はコマンドラインの例です(DreamBooth手法)。
|
||||
以下はコマンドラインの例です。
|
||||
|
||||
```
|
||||
accelerate launch --num_cpu_threads_per_process 1 train_network.py
|
||||
--pretrained_model_name_or_path=..\models\model.ckpt
|
||||
--train_data_dir=..\data\db\char1 --output_dir=..\lora_train1
|
||||
--reg_data_dir=..\data\db\reg1 --prior_loss_weight=1.0
|
||||
--resolution=448,640 --train_batch_size=1 --learning_rate=1e-4
|
||||
--max_train_steps=400 --optimizer_type=AdamW8bit --xformers --mixed_precision=fp16
|
||||
--save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --seed=42 --color_aug
|
||||
--pretrained_model_name_or_path=<.ckptまたは.safetensordまたはDiffusers版モデルのディレクトリ>
|
||||
--dataset_config=<データ準備で作成した.tomlファイル>
|
||||
--output_dir=<学習したモデルの出力先フォルダ>
|
||||
--output_name=<学習したモデル出力時のファイル名>
|
||||
--save_model_as=safetensors
|
||||
--prior_loss_weight=1.0
|
||||
--max_train_steps=400
|
||||
--learning_rate=1e-4
|
||||
--optimizer_type="AdamW8bit"
|
||||
--xformers
|
||||
--mixed_precision="fp16"
|
||||
--cache_latents
|
||||
--gradient_checkpointing
|
||||
--save_every_n_epochs=1
|
||||
--network_module=networks.lora
|
||||
```
|
||||
|
||||
(2023/2/22:オプティマイザの指定方法が変わりました。[こちら](#オプティマイザの指定について)をご覧ください。)
|
||||
|
||||
--output_dirオプションで指定したフォルダに、LoRAのモデルが保存されます。
|
||||
`--output_dir` オプションで指定したフォルダに、LoRAのモデルが保存されます。他のオプション、オプティマイザ等については [学習の共通ドキュメント](./train_README-ja.md) の「よく使われるオプション」も参照してください。
|
||||
|
||||
その他、以下のオプションが指定できます。
|
||||
|
||||
* --network_dim
|
||||
* `--network_dim`
|
||||
* LoRAのRANKを指定します(``--networkdim=4``など)。省略時は4になります。数が多いほど表現力は増しますが、学習に必要なメモリ、時間は増えます。また闇雲に増やしても良くないようです。
|
||||
* --network_alpha
|
||||
* `--network_alpha`
|
||||
* アンダーフローを防ぎ安定して学習するための ``alpha`` 値を指定します。デフォルトは1です。``network_dim``と同じ値を指定すると以前のバージョンと同じ動作になります。
|
||||
* --network_weights
|
||||
* `--network_weights`
|
||||
* 学習前に学習済みのLoRAの重みを読み込み、そこから追加で学習します。
|
||||
* --network_train_unet_only
|
||||
* `--network_train_unet_only`
|
||||
* U-Netに関連するLoRAモジュールのみ有効とします。fine tuning的な学習で指定するとよいかもしれません。
|
||||
* --network_train_text_encoder_only
|
||||
* `--network_train_text_encoder_only`
|
||||
* Text Encoderに関連するLoRAモジュールのみ有効とします。Textual Inversion的な効果が期待できるかもしれません。
|
||||
* --unet_lr
|
||||
* `--unet_lr`
|
||||
* U-Netに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。
|
||||
* --text_encoder_lr
|
||||
* `--text_encoder_lr`
|
||||
* Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。
|
||||
* `--network_args`
|
||||
* 複数の引数を指定できます。後述します。
|
||||
|
||||
--network_train_unet_onlyと--network_train_text_encoder_onlyの両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
||||
`--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。
|
||||
|
||||
## オプティマイザの指定について
|
||||
## LoRA を Conv2d に拡大して適用する
|
||||
|
||||
--optimizer_type オプションでオプティマイザの種類を指定します。以下が指定できます。
|
||||
通常のLoRAは Linear およぴカーネルサイズ 1x1 の Conv2d にのみ適用されますが、カーネルサイズ 3x3 のConv2dに適用を拡大することもできます。
|
||||
|
||||
- AdamW : [torch.optim.AdamW](https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html)
|
||||
- 過去のバージョンのオプション未指定時と同じ
|
||||
- AdamW8bit : 引数は同上
|
||||
- 過去のバージョンの--use_8bit_adam指定時と同じ
|
||||
- Lion : https://github.com/lucidrains/lion-pytorch
|
||||
- 過去のバージョンの--use_lion_optimizer指定時と同じ
|
||||
- SGDNesterov : [torch.optim.SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html), nesterov=True
|
||||
- SGDNesterov8bit : 引数は同上
|
||||
- DAdaptation : https://github.com/facebookresearch/dadaptation
|
||||
- AdaFactor : [Transformers AdaFactor](https://huggingface.co/docs/transformers/main_classes/optimizer_schedules)
|
||||
- 任意のオプティマイザ
|
||||
`--network_args` に以下のように指定してください。`conv_dim` で Conv2d (3x3) の rank を、`conv_alpha` で alpha を指定してください。
|
||||
|
||||
オプティマイザのオプション引数は--optimizer_argsオプションで指定してください。key=valueの形式で、複数の値が指定できます。また、valueはカンマ区切りで複数の値が指定できます。たとえばAdamWオプティマイザに引数を指定する場合は、``--optimizer_args weight_decay=0.01 betas=.9,.999``のようになります。
|
||||
```
|
||||
--network_args "conv_dim=1" "conv_alpha=1"
|
||||
```
|
||||
|
||||
オプション引数を指定する場合は、それぞれのオプティマイザの仕様をご確認ください。
|
||||
以下のように alpha 省略時は1になります。
|
||||
|
||||
一部のオプティマイザでは必須の引数があり、省略すると自動的に追加されます(SGDNesterovのmomentumなど)。コンソールの出力を確認してください。
|
||||
|
||||
D-Adaptationオプティマイザは学習率を自動調整します。学習率のオプションに指定した値は学習率そのものではなくD-Adaptationが決定した学習率の適用率になりますので、通常は1.0を指定してください。Text EncoderにU-Netの半分の学習率を指定したい場合は、``--text_encoder_lr=0.5 --unet_lr=1.0``と指定します。
|
||||
|
||||
AdaFactorオプティマイザはrelative_step=Trueを指定すると学習率を自動調整できます(省略時はデフォルトで追加されます)。自動調整する場合は学習率のスケジューラにはadafactor_schedulerが強制的に使用されます。またscale_parameterとwarmup_initを指定するとよいようです。
|
||||
|
||||
自動調整する場合のオプション指定はたとえば ``--optimizer_args "relative_step=True" "scale_parameter=True" "warmup_init=True"`` のようになります。
|
||||
|
||||
学習率を自動調整しない場合はオプション引数 ``relative_step=False`` を追加してください。その場合、学習率のスケジューラにはconstant_with_warmupが、また勾配のclip normをしないことが推奨されているようです。そのため引数は ``--optimizer_type=adafactor --optimizer_args "relative_step=False" --lr_scheduler="constant_with_warmup" --max_grad_norm=0.0`` のようになります。
|
||||
|
||||
### 任意のオプティマイザを使う
|
||||
|
||||
``torch.optim`` のオプティマイザを使う場合にはクラス名のみを(``--optimizer_type=RMSprop``など)、他のモジュールのオプティマイザを使う時は「モジュール名.クラス名」を指定してください(``--optimizer_type=bitsandbytes.optim.lamb.LAMB``など)。
|
||||
|
||||
(内部でimportlibしているだけで動作は未確認です。必要ならパッケージをインストールしてください。)
|
||||
```
|
||||
--network_args "conv_dim=1"
|
||||
```
|
||||
|
||||
## マージスクリプトについて
|
||||
|
||||
@@ -176,6 +157,27 @@ v1で学習したLoRAとv2で学習したLoRA、rank(次元数)や``alpha``
|
||||
* save_precision
|
||||
* モデル保存時の精度をfloat、fp16、bf16から指定できます。省略時はprecisionと同じ精度になります。
|
||||
|
||||
|
||||
## 複数のrankが異なるLoRAのモデルをマージする
|
||||
|
||||
複数のLoRAをひとつのLoRAで近似します(完全な再現はできません)。`svd_merge_lora.py`を用います。たとえば以下のようなコマンドラインになります。
|
||||
|
||||
```
|
||||
python networks\svd_merge_lora.py
|
||||
--save_to ..\lora_train1\model-char1-style1-merged.safetensors
|
||||
--models ..\lora_train1\last.safetensors ..\lora_train2\last.safetensors
|
||||
--ratios 0.6 0.4 --new_rank 32 --device cuda
|
||||
```
|
||||
|
||||
`merge_lora.py` と主なオプションは同一です。以下のオプションが追加されています。
|
||||
|
||||
- `--new_rank`
|
||||
- 作成するLoRAのrankを指定します。
|
||||
- `--new_conv_rank`
|
||||
- 作成する Conv2d 3x3 LoRA の rank を指定します。省略時は `new_rank` と同じになります。
|
||||
- `--device`
|
||||
- `--device cuda`としてcudaを指定すると計算をGPU上で行います。処理が速くなります。
|
||||
|
||||
## 当リポジトリ内の画像生成スクリプトで生成する
|
||||
|
||||
gen_img_diffusers.pyに、--network_module、--network_weightsの各オプションを追加してください。意味は学習時と同様です。
|
||||
@@ -209,12 +211,14 @@ Text Encoderが二つのモデルで同じ場合にはLoRAはU-NetのみのLoRA
|
||||
|
||||
### その他のオプション
|
||||
|
||||
- --v2
|
||||
- `--v2`
|
||||
- v2.xのStable Diffusionモデルを使う場合に指定してください。
|
||||
- --device
|
||||
- `--device`
|
||||
- ``--device cuda``としてcudaを指定すると計算をGPU上で行います。処理が速くなります(CPUでもそこまで遅くないため、せいぜい倍~数倍程度のようです)。
|
||||
- --save_precision
|
||||
- `--save_precision`
|
||||
- LoRAの保存形式を"float", "fp16", "bf16"から指定します。省略時はfloatになります。
|
||||
- `--conv_dim`
|
||||
- 指定するとLoRAの適用範囲を Conv2d 3x3 へ拡大します。Conv2d 3x3 の rank を指定します。
|
||||
|
||||
## 画像リサイズスクリプト
|
||||
|
||||
@@ -252,7 +256,7 @@ python tools\resize_images_to_resolution.py --max_resolution 512x512,384x384,256
|
||||
|
||||
### cloneofsimo氏のリポジトリとの違い
|
||||
|
||||
12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
|
||||
2022/12/25時点では、当リポジトリはLoRAの適用個所をText EncoderのMLP、U-NetのFFN、Transformerのin/out projectionに拡大し、表現力が増しています。ただその代わりメモリ使用量は増え、8GBぎりぎりになりました。
|
||||
|
||||
またモジュール入れ替え機構は全く異なります。
|
||||
|
||||
|
||||
@@ -181,6 +181,11 @@ def train(args):
|
||||
for tmpl in templates:
|
||||
captions.append(tmpl.format(replace_to))
|
||||
train_dataset_group.add_replacement("", captions)
|
||||
|
||||
if args.num_vectors_per_token > 1:
|
||||
prompt_replacement = (args.token_string, replace_to)
|
||||
else:
|
||||
prompt_replacement = None
|
||||
else:
|
||||
if args.num_vectors_per_token > 1:
|
||||
replace_to = " ".join(token_strings)
|
||||
|
||||
Reference in New Issue
Block a user