mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-14 16:22:28 +00:00
Compare commits
22 Commits
wuerstchen
...
multi_embe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f8629e3c1a | ||
|
|
33ee0acd35 | ||
|
|
8b79e3b06c | ||
|
|
cf49e912fc | ||
|
|
66741c035c | ||
|
|
406511c333 | ||
|
|
8a2d68d63e | ||
|
|
07d297fdbe | ||
|
|
0d4e8b50d0 | ||
|
|
1d7c5c2a98 | ||
|
|
0faa350175 | ||
|
|
8a7509db75 | ||
|
|
025368f51c | ||
|
|
5fe52ed322 | ||
|
|
8b247a330b | ||
|
|
d6f458fcb3 | ||
|
|
b8b84021e5 | ||
|
|
70fe7e18be | ||
|
|
9378da3c82 | ||
|
|
a4857fa764 | ||
|
|
592014923f | ||
|
|
6d06b215bf |
2
.github/workflows/typos.yml
vendored
2
.github/workflows/typos.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: typos-action
|
- name: typos-action
|
||||||
uses: crate-ci/typos@v1.16.15
|
uses: crate-ci/typos@v1.16.15
|
||||||
|
|||||||
34
README.md
34
README.md
@@ -249,6 +249,40 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
|||||||
|
|
||||||
## Change History
|
## Change History
|
||||||
|
|
||||||
|
### Oct 9. 2023 / 2023/10/9
|
||||||
|
|
||||||
|
- `tag_images_by_wd_14_tagger.py` now supports Onnx. If you use Onnx, TensorFlow is not required anymore. [#864](https://github.com/kohya-ss/sd-scripts/pull/864) Thanks to Isotr0py!
|
||||||
|
- `--onnx` option is added. If you use Onnx, specify `--onnx` option.
|
||||||
|
- Please install Onnx and other required packages.
|
||||||
|
1. Uninstall TensorFlow.
|
||||||
|
1. `pip install tensorboard==2.14.1` This is required for the specified version of protobuf.
|
||||||
|
1. `pip install protobuf==3.20.3` This is required for Onnx.
|
||||||
|
1. `pip install onnx==1.14.1`
|
||||||
|
1. `pip install onnxruntime-gpu==1.16.0` or `pip install onnxruntime==1.16.0`
|
||||||
|
- `--append_tags` option is added to `tag_images_by_wd_14_tagger.py`. This option appends the tags to the existing tags, instead of replacing them. [#858](https://github.com/kohya-ss/sd-scripts/pull/858) Thanks to a-l-e-x-d-s-9!
|
||||||
|
- [OFT](https://oft.wyliu.com/) is now supported.
|
||||||
|
- You can use `networks.oft` for the network module in `sdxl_train_network.py`. The usage is the same as `networks.lora`. Some options are not supported.
|
||||||
|
- `sdxl_gen_img.py` also supports OFT as `--network_module`.
|
||||||
|
- OFT only supports SDXL currently. Because current OFT tweaks Q/K/V and O in the transformer, and SD1/2 have extremely fewer transformers than SDXL.
|
||||||
|
- The implementation is heavily based on laksjdjf's [OFT implementation](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py). Thanks to laksjdjf!
|
||||||
|
- Other bug fixes and improvements.
|
||||||
|
|
||||||
|
- `tag_images_by_wd_14_tagger.py` が Onnx をサポートしました。Onnx を使用する場合は TensorFlow は不要です。[#864](https://github.com/kohya-ss/sd-scripts/pull/864) Isotr0py氏に感謝します。
|
||||||
|
- Onnxを使用する場合は、`--onnx` オプションを指定してください。
|
||||||
|
- Onnx とその他の必要なパッケージをインストールしてください。
|
||||||
|
1. TensorFlow をアンインストールしてください。
|
||||||
|
1. `pip install tensorboard==2.14.1` protobufの指定バージョンにこれが必要。
|
||||||
|
1. `pip install protobuf==3.20.3` Onnxのために必要。
|
||||||
|
1. `pip install onnx==1.14.1`
|
||||||
|
1. `pip install onnxruntime-gpu==1.16.0` または `pip install onnxruntime==1.16.0`
|
||||||
|
- `tag_images_by_wd_14_tagger.py` に `--append_tags` オプションが追加されました。このオプションを指定すると、既存のタグに上書きするのではなく、新しいタグのみが既存のタグに追加されます。 [#858](https://github.com/kohya-ss/sd-scripts/pull/858) a-l-e-x-d-s-9氏に感謝します。
|
||||||
|
- [OFT](https://oft.wyliu.com/) をサポートしました。
|
||||||
|
- `sdxl_train_network.py` の`--network_module`に `networks.oft` を指定してください。使用方法は `networks.lora` と同様ですが一部のオプションは未サポートです。
|
||||||
|
- `sdxl_gen_img.py` でも同様に OFT を指定できます。
|
||||||
|
- OFT は現在 SDXL のみサポートしています。OFT は現在 transformer の Q/K/V と O を変更しますが、SD1/2 は transformer の数が SDXL よりも極端に少ないためです。
|
||||||
|
- 実装は laksjdjf 氏の [OFT実装](https://github.com/laksjdjf/sd-trainer/blob/dev/networks/lora_modules.py) を多くの部分で参考にしています。laksjdjf 氏に感謝します。
|
||||||
|
- その他のバグ修正と改善。
|
||||||
|
|
||||||
### Oct 1. 2023 / 2023/10/1
|
### Oct 1. 2023 / 2023/10/1
|
||||||
|
|
||||||
- SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.
|
- SDXL training is now available in the main branch. The sdxl branch is merged into the main branch.
|
||||||
|
|||||||
@@ -1,17 +1,15 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import csv
|
import csv
|
||||||
import glob
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
import cv2
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy as np
|
|
||||||
from tensorflow.keras.models import load_model
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
import torch
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
|
||||||
# from wd14 tagger
|
# from wd14 tagger
|
||||||
@@ -20,6 +18,7 @@ IMAGE_SIZE = 448
|
|||||||
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
# wd-v1-4-swinv2-tagger-v2 / wd-v1-4-vit-tagger / wd-v1-4-vit-tagger-v2/ wd-v1-4-convnext-tagger / wd-v1-4-convnext-tagger-v2
|
||||||
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
DEFAULT_WD14_TAGGER_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
|
||||||
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
FILES = ["keras_metadata.pb", "saved_model.pb", "selected_tags.csv"]
|
||||||
|
FILES_ONNX = ["model.onnx"]
|
||||||
SUB_DIR = "variables"
|
SUB_DIR = "variables"
|
||||||
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
SUB_DIR_FILES = ["variables.data-00000-of-00001", "variables.index"]
|
||||||
CSV_FILE = FILES[-1]
|
CSV_FILE = FILES[-1]
|
||||||
@@ -81,7 +80,10 @@ def main(args):
|
|||||||
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/issues/22
|
||||||
if not os.path.exists(args.model_dir) or args.force_download:
|
if not os.path.exists(args.model_dir) or args.force_download:
|
||||||
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
print(f"downloading wd14 tagger model from hf_hub. id: {args.repo_id}")
|
||||||
for file in FILES:
|
files = FILES
|
||||||
|
if args.onnx:
|
||||||
|
files += FILES_ONNX
|
||||||
|
for file in files:
|
||||||
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
hf_hub_download(args.repo_id, file, cache_dir=args.model_dir, force_download=True, force_filename=file)
|
||||||
for file in SUB_DIR_FILES:
|
for file in SUB_DIR_FILES:
|
||||||
hf_hub_download(
|
hf_hub_download(
|
||||||
@@ -96,7 +98,46 @@ def main(args):
|
|||||||
print("using existing wd14 tagger model")
|
print("using existing wd14 tagger model")
|
||||||
|
|
||||||
# 画像を読み込む
|
# 画像を読み込む
|
||||||
model = load_model(args.model_dir)
|
if args.onnx:
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
|
||||||
|
onnx_path = f"{args.model_dir}/model.onnx"
|
||||||
|
print("Running wd14 tagger with onnx")
|
||||||
|
print(f"loading onnx model: {onnx_path}")
|
||||||
|
|
||||||
|
if not os.path.exists(onnx_path):
|
||||||
|
raise Exception(
|
||||||
|
f"onnx model not found: {onnx_path}, please redownload the model with --force_download"
|
||||||
|
+ " / onnxモデルが見つかりませんでした。--force_downloadで再ダウンロードしてください"
|
||||||
|
)
|
||||||
|
|
||||||
|
model = onnx.load(onnx_path)
|
||||||
|
input_name = model.graph.input[0].name
|
||||||
|
try:
|
||||||
|
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_value
|
||||||
|
except:
|
||||||
|
batch_size = model.graph.input[0].type.tensor_type.shape.dim[0].dim_param
|
||||||
|
|
||||||
|
if args.batch_size != batch_size and type(batch_size) != str:
|
||||||
|
# some rebatch model may use 'N' as dynamic axes
|
||||||
|
print(
|
||||||
|
f"Batch size {args.batch_size} doesn't match onnx model batch size {batch_size}, use model batch size {batch_size}"
|
||||||
|
)
|
||||||
|
args.batch_size = batch_size
|
||||||
|
|
||||||
|
del model
|
||||||
|
|
||||||
|
ort_sess = ort.InferenceSession(
|
||||||
|
onnx_path,
|
||||||
|
providers=["CUDAExecutionProvider"]
|
||||||
|
if "CUDAExecutionProvider" in ort.get_available_providers()
|
||||||
|
else ["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
|
||||||
|
model = load_model(f"{args.model_dir}")
|
||||||
|
|
||||||
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
# label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
|
||||||
# 依存ライブラリを増やしたくないので自力で読むよ
|
# 依存ライブラリを増やしたくないので自力で読むよ
|
||||||
@@ -124,8 +165,14 @@ def main(args):
|
|||||||
def run_batch(path_imgs):
|
def run_batch(path_imgs):
|
||||||
imgs = np.array([im for _, im in path_imgs])
|
imgs = np.array([im for _, im in path_imgs])
|
||||||
|
|
||||||
probs = model(imgs, training=False)
|
if args.onnx:
|
||||||
probs = probs.numpy()
|
if len(imgs) < args.batch_size:
|
||||||
|
imgs = np.concatenate([imgs, np.zeros((args.batch_size - len(imgs), IMAGE_SIZE, IMAGE_SIZE, 3))], axis=0)
|
||||||
|
probs = ort_sess.run(None, {input_name: imgs})[0] # onnx output numpy
|
||||||
|
probs = probs[: len(path_imgs)]
|
||||||
|
else:
|
||||||
|
probs = model(imgs, training=False)
|
||||||
|
probs = probs.numpy()
|
||||||
|
|
||||||
for (image_path, _), prob in zip(path_imgs, probs):
|
for (image_path, _), prob in zip(path_imgs, probs):
|
||||||
# 最初の4つはratingなので無視する
|
# 最初の4つはratingなので無視する
|
||||||
@@ -165,9 +212,27 @@ def main(args):
|
|||||||
if len(character_tag_text) > 0:
|
if len(character_tag_text) > 0:
|
||||||
character_tag_text = character_tag_text[2:]
|
character_tag_text = character_tag_text[2:]
|
||||||
|
|
||||||
|
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||||
|
|
||||||
tag_text = ", ".join(combined_tags)
|
tag_text = ", ".join(combined_tags)
|
||||||
|
|
||||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
if args.append_tags:
|
||||||
|
# Check if file exists
|
||||||
|
if os.path.exists(caption_file):
|
||||||
|
with open(caption_file, "rt", encoding="utf-8") as f:
|
||||||
|
# Read file and remove new lines
|
||||||
|
existing_content = f.read().strip("\n") # Remove newlines
|
||||||
|
|
||||||
|
# Split the content into tags and store them in a list
|
||||||
|
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
|
||||||
|
|
||||||
|
# Check and remove repeating tags in tag_text
|
||||||
|
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
|
||||||
|
|
||||||
|
# Create new tag_text
|
||||||
|
tag_text = ", ".join(existing_tags + new_tags)
|
||||||
|
|
||||||
|
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||||
f.write(tag_text + "\n")
|
f.write(tag_text + "\n")
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
||||||
@@ -283,12 +348,15 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
help="comma-separated list of undesired tags to remove from the output / 出力から除外したいタグのカンマ区切りのリスト",
|
||||||
)
|
)
|
||||||
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
||||||
|
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference / onnxモデルを推論に使用する")
|
||||||
|
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# スペルミスしていたオプションを復元する
|
# スペルミスしていたオプションを復元する
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# JPEG-XL on Linux
|
||||||
try:
|
try:
|
||||||
from jxlpy import JXLImagePlugin
|
from jxlpy import JXLImagePlugin
|
||||||
|
|
||||||
@@ -103,6 +104,14 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# JPEG-XL on Windows
|
||||||
|
try:
|
||||||
|
import pillow_jxl
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
IMAGE_TRANSFORMS = transforms.Compose(
|
IMAGE_TRANSFORMS = transforms.Compose(
|
||||||
[
|
[
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
@@ -1995,7 +2004,7 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
|
|
||||||
if show_input_ids:
|
if show_input_ids:
|
||||||
print(f"input ids: {iid}")
|
print(f"input ids: {iid}")
|
||||||
if "input_ids2" in example and example["input_ids2"] is not None:
|
if "input_ids2" in example:
|
||||||
print(f"input ids2: {example['input_ids2'][j]}")
|
print(f"input ids2: {example['input_ids2'][j]}")
|
||||||
if example["images"] is not None:
|
if example["images"] is not None:
|
||||||
im = example["images"][j]
|
im = example["images"][j]
|
||||||
@@ -2012,11 +2021,6 @@ def debug_dataset(train_dataset, show_input_ids=False):
|
|||||||
cond_img = cond_img[:, :, ::-1]
|
cond_img = cond_img[:, :, ::-1]
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
cv2.imshow("cond_img", cond_img)
|
cv2.imshow("cond_img", cond_img)
|
||||||
|
|
||||||
for key in example.keys():
|
|
||||||
if key in ["images", "conditioning_images", "input_ids", "input_ids2"]:
|
|
||||||
continue
|
|
||||||
print(f"{key}: {example[key][j] if example[key] is not None else None}")
|
|
||||||
|
|
||||||
if os.name == "nt": # only windows
|
if os.name == "nt": # only windows
|
||||||
cv2.imshow("img", im)
|
cv2.imshow("img", im)
|
||||||
|
|||||||
430
networks/oft.py
Normal file
430
networks/oft.py
Normal file
@@ -0,0 +1,430 @@
|
|||||||
|
# OFT network module
|
||||||
|
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
from diffusers import AutoencoderKL
|
||||||
|
from transformers import CLIPTextModel
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
|
|
||||||
|
|
||||||
|
class OFTModule(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
replaces forward method of the original Linear, instead of replacing the original Linear module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
oft_name,
|
||||||
|
org_module: torch.nn.Module,
|
||||||
|
multiplier=1.0,
|
||||||
|
dim=4,
|
||||||
|
alpha=1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
dim -> num blocks
|
||||||
|
alpha -> constraint
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.oft_name = oft_name
|
||||||
|
|
||||||
|
self.num_blocks = dim
|
||||||
|
|
||||||
|
if "Linear" in org_module.__class__.__name__:
|
||||||
|
out_dim = org_module.out_features
|
||||||
|
elif "Conv" in org_module.__class__.__name__:
|
||||||
|
out_dim = org_module.out_channels
|
||||||
|
|
||||||
|
if type(alpha) == torch.Tensor:
|
||||||
|
alpha = alpha.detach().numpy()
|
||||||
|
self.constraint = alpha * out_dim
|
||||||
|
self.register_buffer("alpha", torch.tensor(alpha))
|
||||||
|
|
||||||
|
self.block_size = out_dim // self.num_blocks
|
||||||
|
self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size))
|
||||||
|
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.shape = org_module.weight.shape
|
||||||
|
|
||||||
|
self.multiplier = multiplier
|
||||||
|
self.org_module = [org_module] # moduleにならないようにlistに入れる
|
||||||
|
|
||||||
|
def apply_to(self):
|
||||||
|
self.org_forward = self.org_module[0].forward
|
||||||
|
self.org_module[0].forward = self.forward
|
||||||
|
|
||||||
|
def get_weight(self, multiplier=None):
|
||||||
|
if multiplier is None:
|
||||||
|
multiplier = self.multiplier
|
||||||
|
|
||||||
|
block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
|
||||||
|
norm_Q = torch.norm(block_Q.flatten())
|
||||||
|
new_norm_Q = torch.clamp(norm_Q, max=self.constraint)
|
||||||
|
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
|
||||||
|
I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1)
|
||||||
|
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
|
||||||
|
|
||||||
|
block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I
|
||||||
|
R = torch.block_diag(*block_R_weighted)
|
||||||
|
|
||||||
|
return R
|
||||||
|
|
||||||
|
def forward(self, x, scale=None):
|
||||||
|
x = self.org_forward(x)
|
||||||
|
if self.multiplier == 0.0:
|
||||||
|
return x
|
||||||
|
|
||||||
|
R = self.get_weight().to(x.device, dtype=x.dtype)
|
||||||
|
if x.dim() == 4:
|
||||||
|
x = x.permute(0, 2, 3, 1)
|
||||||
|
x = torch.matmul(x, R)
|
||||||
|
x = x.permute(0, 3, 1, 2)
|
||||||
|
else:
|
||||||
|
x = torch.matmul(x, R)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class OFTInfModule(OFTModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
oft_name,
|
||||||
|
org_module: torch.nn.Module,
|
||||||
|
multiplier=1.0,
|
||||||
|
dim=4,
|
||||||
|
alpha=1,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
# no dropout for inference
|
||||||
|
super().__init__(oft_name, org_module, multiplier, dim, alpha)
|
||||||
|
self.enabled = True
|
||||||
|
self.network: OFTNetwork = None
|
||||||
|
|
||||||
|
def set_network(self, network):
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def forward(self, x, scale=None):
|
||||||
|
if not self.enabled:
|
||||||
|
return self.org_forward(x)
|
||||||
|
return super().forward(x, scale)
|
||||||
|
|
||||||
|
def merge_to(self, multiplier=None, sign=1):
|
||||||
|
R = self.get_weight(multiplier) * sign
|
||||||
|
|
||||||
|
# get org weight
|
||||||
|
org_sd = self.org_module[0].state_dict()
|
||||||
|
org_weight = org_sd["weight"]
|
||||||
|
R = R.to(org_weight.device, dtype=org_weight.dtype)
|
||||||
|
|
||||||
|
if org_weight.dim() == 4:
|
||||||
|
weight = torch.einsum("oihw, op -> pihw", org_weight, R)
|
||||||
|
else:
|
||||||
|
weight = torch.einsum("oi, op -> pi", org_weight, R)
|
||||||
|
|
||||||
|
# set weight to org_module
|
||||||
|
org_sd["weight"] = weight
|
||||||
|
self.org_module[0].load_state_dict(org_sd)
|
||||||
|
|
||||||
|
|
||||||
|
def create_network(
|
||||||
|
multiplier: float,
|
||||||
|
network_dim: Optional[int],
|
||||||
|
network_alpha: Optional[float],
|
||||||
|
vae: AutoencoderKL,
|
||||||
|
text_encoder: Union[CLIPTextModel, List[CLIPTextModel]],
|
||||||
|
unet,
|
||||||
|
neuron_dropout: Optional[float] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if network_dim is None:
|
||||||
|
network_dim = 4 # default
|
||||||
|
if network_alpha is None:
|
||||||
|
network_alpha = 1.0
|
||||||
|
|
||||||
|
enable_all_linear = kwargs.get("enable_all_linear", None)
|
||||||
|
enable_conv = kwargs.get("enable_conv", None)
|
||||||
|
if enable_all_linear is not None:
|
||||||
|
enable_all_linear = bool(enable_all_linear)
|
||||||
|
if enable_conv is not None:
|
||||||
|
enable_conv = bool(enable_conv)
|
||||||
|
|
||||||
|
network = OFTNetwork(
|
||||||
|
text_encoder,
|
||||||
|
unet,
|
||||||
|
multiplier=multiplier,
|
||||||
|
dim=network_dim,
|
||||||
|
alpha=network_alpha,
|
||||||
|
enable_all_linear=enable_all_linear,
|
||||||
|
enable_conv=enable_conv,
|
||||||
|
varbose=True,
|
||||||
|
)
|
||||||
|
return network
|
||||||
|
|
||||||
|
|
||||||
|
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||||
|
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||||
|
if weights_sd is None:
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import load_file, safe_open
|
||||||
|
|
||||||
|
weights_sd = load_file(file)
|
||||||
|
else:
|
||||||
|
weights_sd = torch.load(file, map_location="cpu")
|
||||||
|
|
||||||
|
# check dim, alpha and if weights have for conv2d
|
||||||
|
dim = None
|
||||||
|
alpha = None
|
||||||
|
has_conv2d = None
|
||||||
|
all_linear = None
|
||||||
|
for name, param in weights_sd.items():
|
||||||
|
if name.endswith(".alpha"):
|
||||||
|
if alpha is None:
|
||||||
|
alpha = param.item()
|
||||||
|
else:
|
||||||
|
if dim is None:
|
||||||
|
dim = param.size()[0]
|
||||||
|
if has_conv2d is None and param.dim() == 4:
|
||||||
|
has_conv2d = True
|
||||||
|
if all_linear is None:
|
||||||
|
if param.dim() == 3 and "attn" not in name:
|
||||||
|
all_linear = True
|
||||||
|
if dim is not None and alpha is not None and has_conv2d is not None:
|
||||||
|
break
|
||||||
|
if has_conv2d is None:
|
||||||
|
has_conv2d = False
|
||||||
|
if all_linear is None:
|
||||||
|
all_linear = False
|
||||||
|
|
||||||
|
module_class = OFTInfModule if for_inference else OFTModule
|
||||||
|
network = OFTNetwork(
|
||||||
|
text_encoder,
|
||||||
|
unet,
|
||||||
|
multiplier=multiplier,
|
||||||
|
dim=dim,
|
||||||
|
alpha=alpha,
|
||||||
|
enable_all_linear=all_linear,
|
||||||
|
enable_conv=has_conv2d,
|
||||||
|
module_class=module_class,
|
||||||
|
)
|
||||||
|
return network, weights_sd
|
||||||
|
|
||||||
|
|
||||||
|
class OFTNetwork(torch.nn.Module):
|
||||||
|
UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"]
|
||||||
|
UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"]
|
||||||
|
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
|
||||||
|
OFT_PREFIX_UNET = "oft_unet" # これ変えないほうがいいかな
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_encoder: Union[List[CLIPTextModel], CLIPTextModel],
|
||||||
|
unet,
|
||||||
|
multiplier: float = 1.0,
|
||||||
|
dim: int = 4,
|
||||||
|
alpha: float = 1,
|
||||||
|
enable_all_linear: Optional[bool] = False,
|
||||||
|
enable_conv: Optional[bool] = False,
|
||||||
|
module_class: Type[object] = OFTModule,
|
||||||
|
varbose: Optional[bool] = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.multiplier = multiplier
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.alpha = alpha
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# create module instances
|
||||||
|
def create_modules(
|
||||||
|
root_module: torch.nn.Module,
|
||||||
|
target_replace_modules: List[torch.nn.Module],
|
||||||
|
) -> List[OFTModule]:
|
||||||
|
prefix = self.OFT_PREFIX_UNET
|
||||||
|
ofts = []
|
||||||
|
for name, module in root_module.named_modules():
|
||||||
|
if module.__class__.__name__ in target_replace_modules:
|
||||||
|
for child_name, child_module in module.named_modules():
|
||||||
|
is_linear = "Linear" in child_module.__class__.__name__
|
||||||
|
is_conv2d = "Conv2d" in child_module.__class__.__name__
|
||||||
|
is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1)
|
||||||
|
|
||||||
|
if is_linear or is_conv2d_1x1 or (is_conv2d and enable_conv):
|
||||||
|
oft_name = prefix + "." + name + "." + child_name
|
||||||
|
oft_name = oft_name.replace(".", "_")
|
||||||
|
# print(oft_name)
|
||||||
|
|
||||||
|
oft = module_class(
|
||||||
|
oft_name,
|
||||||
|
child_module,
|
||||||
|
self.multiplier,
|
||||||
|
dim,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
ofts.append(oft)
|
||||||
|
return ofts
|
||||||
|
|
||||||
|
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
|
||||||
|
if enable_all_linear:
|
||||||
|
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR
|
||||||
|
else:
|
||||||
|
target_modules = OFTNetwork.UNET_TARGET_REPLACE_MODULE_ATTN_ONLY
|
||||||
|
if enable_conv:
|
||||||
|
target_modules += OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||||
|
|
||||||
|
self.unet_ofts: List[OFTModule] = create_modules(unet, target_modules)
|
||||||
|
print(f"create OFT for U-Net: {len(self.unet_ofts)} modules.")
|
||||||
|
|
||||||
|
# assertion
|
||||||
|
names = set()
|
||||||
|
for oft in self.unet_ofts:
|
||||||
|
assert oft.oft_name not in names, f"duplicated oft name: {oft.oft_name}"
|
||||||
|
names.add(oft.oft_name)
|
||||||
|
|
||||||
|
def set_multiplier(self, multiplier):
|
||||||
|
self.multiplier = multiplier
|
||||||
|
for oft in self.unet_ofts:
|
||||||
|
oft.multiplier = self.multiplier
|
||||||
|
|
||||||
|
def load_weights(self, file):
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
|
weights_sd = load_file(file)
|
||||||
|
else:
|
||||||
|
weights_sd = torch.load(file, map_location="cpu")
|
||||||
|
|
||||||
|
info = self.load_state_dict(weights_sd, False)
|
||||||
|
return info
|
||||||
|
|
||||||
|
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
|
||||||
|
assert apply_unet, "apply_unet must be True"
|
||||||
|
|
||||||
|
for oft in self.unet_ofts:
|
||||||
|
oft.apply_to()
|
||||||
|
self.add_module(oft.oft_name, oft)
|
||||||
|
|
||||||
|
# マージできるかどうかを返す
|
||||||
|
def is_mergeable(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# TODO refactor to common function with apply_to
|
||||||
|
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||||
|
print("enable OFT for U-Net")
|
||||||
|
|
||||||
|
for oft in self.unet_ofts:
|
||||||
|
sd_for_lora = {}
|
||||||
|
for key in weights_sd.keys():
|
||||||
|
if key.startswith(oft.oft_name):
|
||||||
|
sd_for_lora[key[len(oft.oft_name) + 1 :]] = weights_sd[key]
|
||||||
|
oft.load_state_dict(sd_for_lora, False)
|
||||||
|
oft.merge_to()
|
||||||
|
|
||||||
|
print(f"weights are merged")
|
||||||
|
|
||||||
|
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも
|
||||||
|
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
|
||||||
|
self.requires_grad_(True)
|
||||||
|
all_params = []
|
||||||
|
|
||||||
|
def enumerate_params(ofts):
|
||||||
|
params = []
|
||||||
|
for oft in ofts:
|
||||||
|
params.extend(oft.parameters())
|
||||||
|
|
||||||
|
# print num of params
|
||||||
|
num_params = 0
|
||||||
|
for p in params:
|
||||||
|
num_params += p.numel()
|
||||||
|
print(f"OFT params: {num_params}")
|
||||||
|
return params
|
||||||
|
|
||||||
|
param_data = {"params": enumerate_params(self.unet_ofts)}
|
||||||
|
if unet_lr is not None:
|
||||||
|
param_data["lr"] = unet_lr
|
||||||
|
all_params.append(param_data)
|
||||||
|
|
||||||
|
return all_params
|
||||||
|
|
||||||
|
def enable_gradient_checkpointing(self):
|
||||||
|
# not supported
|
||||||
|
pass
|
||||||
|
|
||||||
|
def prepare_grad_etc(self, text_encoder, unet):
|
||||||
|
self.requires_grad_(True)
|
||||||
|
|
||||||
|
def on_epoch_start(self, text_encoder, unet):
|
||||||
|
self.train()
|
||||||
|
|
||||||
|
def get_trainable_params(self):
|
||||||
|
return self.parameters()
|
||||||
|
|
||||||
|
def save_weights(self, file, dtype, metadata):
|
||||||
|
if metadata is not None and len(metadata) == 0:
|
||||||
|
metadata = None
|
||||||
|
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
|
||||||
|
if dtype is not None:
|
||||||
|
for key in list(state_dict.keys()):
|
||||||
|
v = state_dict[key]
|
||||||
|
v = v.detach().clone().to("cpu").to(dtype)
|
||||||
|
state_dict[key] = v
|
||||||
|
|
||||||
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
from library import train_util
|
||||||
|
|
||||||
|
# Precalculate model hashes to save time on indexing
|
||||||
|
if metadata is None:
|
||||||
|
metadata = {}
|
||||||
|
model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
|
||||||
|
metadata["sshs_model_hash"] = model_hash
|
||||||
|
metadata["sshs_legacy_hash"] = legacy_hash
|
||||||
|
|
||||||
|
save_file(state_dict, file, metadata)
|
||||||
|
else:
|
||||||
|
torch.save(state_dict, file)
|
||||||
|
|
||||||
|
def backup_weights(self):
|
||||||
|
# 重みのバックアップを行う
|
||||||
|
ofts: List[OFTInfModule] = self.unet_ofts
|
||||||
|
for oft in ofts:
|
||||||
|
org_module = oft.org_module[0]
|
||||||
|
if not hasattr(org_module, "_lora_org_weight"):
|
||||||
|
sd = org_module.state_dict()
|
||||||
|
org_module._lora_org_weight = sd["weight"].detach().clone()
|
||||||
|
org_module._lora_restored = True
|
||||||
|
|
||||||
|
def restore_weights(self):
|
||||||
|
# 重みのリストアを行う
|
||||||
|
ofts: List[OFTInfModule] = self.unet_ofts
|
||||||
|
for oft in ofts:
|
||||||
|
org_module = oft.org_module[0]
|
||||||
|
if not org_module._lora_restored:
|
||||||
|
sd = org_module.state_dict()
|
||||||
|
sd["weight"] = org_module._lora_org_weight
|
||||||
|
org_module.load_state_dict(sd)
|
||||||
|
org_module._lora_restored = True
|
||||||
|
|
||||||
|
def pre_calculation(self):
|
||||||
|
# 事前計算を行う
|
||||||
|
ofts: List[OFTInfModule] = self.unet_ofts
|
||||||
|
for oft in ofts:
|
||||||
|
org_module = oft.org_module[0]
|
||||||
|
oft.merge_to()
|
||||||
|
# sd = org_module.state_dict()
|
||||||
|
# org_weight = sd["weight"]
|
||||||
|
# lora_weight = oft.get_weight().to(org_weight.device, dtype=org_weight.dtype)
|
||||||
|
# sd["weight"] = org_weight + lora_weight
|
||||||
|
# assert sd["weight"].shape == org_weight.shape
|
||||||
|
# org_module.load_state_dict(sd)
|
||||||
|
|
||||||
|
org_module._lora_restored = False
|
||||||
|
oft.enabled = False
|
||||||
@@ -19,8 +19,14 @@ huggingface-hub==0.15.1
|
|||||||
# requests==2.28.2
|
# requests==2.28.2
|
||||||
# timm==0.6.12
|
# timm==0.6.12
|
||||||
# fairscale==0.4.13
|
# fairscale==0.4.13
|
||||||
# for WD14 captioning
|
# for WD14 captioning (tensorflow)
|
||||||
# tensorflow==2.10.1
|
# tensorflow==2.10.1
|
||||||
|
# for WD14 captioning (onnx)
|
||||||
|
# onnx==1.14.1
|
||||||
|
# onnxruntime-gpu==1.16.0
|
||||||
|
# onnxruntime==1.16.0
|
||||||
|
# this is for onnx:
|
||||||
|
# protobuf==3.20.3
|
||||||
# open clip for SDXL
|
# open clip for SDXL
|
||||||
open-clip-torch==2.20.0
|
open-clip-torch==2.20.0
|
||||||
# for kohya_ss library
|
# for kohya_ss library
|
||||||
|
|||||||
84
tools/split_ti_embeddings.py
Normal file
84
tools/split_ti_embeddings.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors import safe_open
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def split(args):
|
||||||
|
# load embedding
|
||||||
|
if args.embedding.endswith(".safetensors"):
|
||||||
|
embedding = load_file(args.embedding)
|
||||||
|
with safe_open(args.embedding, framework="pt") as f:
|
||||||
|
metadata = f.metadata()
|
||||||
|
else:
|
||||||
|
embedding = torch.load(args.embedding)
|
||||||
|
metadata = None
|
||||||
|
|
||||||
|
# check format
|
||||||
|
if "emb_params" in embedding:
|
||||||
|
# SD1/2
|
||||||
|
keys = ["emb_params"]
|
||||||
|
elif "clip_l" in embedding:
|
||||||
|
# SDXL
|
||||||
|
keys = ["clip_l", "clip_g"]
|
||||||
|
else:
|
||||||
|
print("Unknown embedding format")
|
||||||
|
exit()
|
||||||
|
num_vectors = embedding[keys[0]].shape[0]
|
||||||
|
|
||||||
|
# prepare output directory
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# prepare splits
|
||||||
|
if args.vectors_per_split is not None:
|
||||||
|
num_splits = (num_vectors + args.vectors_per_split - 1) // args.vectors_per_split
|
||||||
|
vectors_for_split = [args.vectors_per_split] * num_splits
|
||||||
|
if sum(vectors_for_split) > num_vectors:
|
||||||
|
vectors_for_split[-1] -= sum(vectors_for_split) - num_vectors
|
||||||
|
assert sum(vectors_for_split) == num_vectors
|
||||||
|
elif args.vectors is not None:
|
||||||
|
vectors_for_split = args.vectors
|
||||||
|
num_splits = len(vectors_for_split)
|
||||||
|
else:
|
||||||
|
print("Must specify either --vectors_per_split or --vectors / --vectors_per_split または --vectors のどちらかを指定する必要があります")
|
||||||
|
exit()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
sum(vectors_for_split) == num_vectors
|
||||||
|
), "Sum of vectors must be equal to the number of vectors in the embedding / 分割したベクトルの合計はembeddingのベクトル数と等しくなければなりません"
|
||||||
|
|
||||||
|
# split
|
||||||
|
basename = os.path.splitext(os.path.basename(args.embedding))[0]
|
||||||
|
done_vectors = 0
|
||||||
|
for i, num_vectors in enumerate(vectors_for_split):
|
||||||
|
print(f"Splitting {num_vectors} vectors...")
|
||||||
|
|
||||||
|
split_embedding = {}
|
||||||
|
for key in keys:
|
||||||
|
split_embedding[key] = embedding[key][done_vectors : done_vectors + num_vectors]
|
||||||
|
|
||||||
|
output_file = os.path.join(args.output_dir, f"{basename}_{i}.safetensors")
|
||||||
|
save_file(split_embedding, output_file, metadata)
|
||||||
|
print(f"Saved to {output_file}")
|
||||||
|
|
||||||
|
done_vectors += num_vectors
|
||||||
|
|
||||||
|
print("Done")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Merge models")
|
||||||
|
parser.add_argument("--embedding", type=str, help="Embedding to split")
|
||||||
|
parser.add_argument("--output_dir", type=str, help="Output directory")
|
||||||
|
parser.add_argument(
|
||||||
|
"--vectors_per_split",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Number of vectors per split. If num_vectors is 8 and vectors_per_split is 3, then 3, 3, 2 vectors will be split",
|
||||||
|
)
|
||||||
|
parser.add_argument("--vectors", type=int, default=None, nargs="*", help="number of vectors for each split. e.g. 3 3 2")
|
||||||
|
args = parser.parse_args()
|
||||||
|
split(args)
|
||||||
@@ -283,7 +283,10 @@ class NetworkTrainer:
|
|||||||
if args.dim_from_weights:
|
if args.dim_from_weights:
|
||||||
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
||||||
else:
|
else:
|
||||||
# LyCORIS will work with this...
|
if "dropout" not in net_kwargs:
|
||||||
|
# workaround for LyCORIS (;^ω^)
|
||||||
|
net_kwargs["dropout"] = args.network_dropout
|
||||||
|
|
||||||
network = network_module.create_network(
|
network = network_module.create_network(
|
||||||
1.0,
|
1.0,
|
||||||
args.network_dim,
|
args.network_dim,
|
||||||
|
|||||||
@@ -7,10 +7,13 @@ import toml
|
|||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
if torch.xpu.is_available():
|
||||||
from library.ipex import ipex_init
|
from library.ipex import ipex_init
|
||||||
|
|
||||||
ipex_init()
|
ipex_init()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -167,6 +170,13 @@ class TextualInversionTrainer:
|
|||||||
args.output_name = args.token_string
|
args.output_name = args.token_string
|
||||||
use_template = args.use_object_template or args.use_style_template
|
use_template = args.use_object_template or args.use_style_template
|
||||||
|
|
||||||
|
assert (
|
||||||
|
args.token_string is not None or args.token_strings is not None
|
||||||
|
), "token_string or token_strings must be specified / token_stringまたはtoken_stringsを指定してください"
|
||||||
|
assert (
|
||||||
|
not use_template or args.token_strings is None
|
||||||
|
), "token_strings cannot be used with template / token_stringsはテンプレートと一緒に使えません"
|
||||||
|
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
|
|
||||||
@@ -215,9 +225,17 @@ class TextualInversionTrainer:
|
|||||||
# add new word to tokenizer, count is num_vectors_per_token
|
# add new word to tokenizer, count is num_vectors_per_token
|
||||||
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
|
# if token_string is hoge, "hoge", "hoge1", "hoge2", ... are added
|
||||||
|
|
||||||
self.assert_token_string(args.token_string, tokenizers)
|
if args.token_strings is not None:
|
||||||
|
token_strings = args.token_strings
|
||||||
|
assert (
|
||||||
|
len(token_strings) == args.num_vectors_per_token
|
||||||
|
), f"num_vectors_per_token is mismatch for token_strings / token_stringsの数がnum_vectors_per_tokenと合いません: {len(token_strings)}"
|
||||||
|
for token_string in token_strings:
|
||||||
|
self.assert_token_string(token_string, tokenizers)
|
||||||
|
else:
|
||||||
|
self.assert_token_string(args.token_string, tokenizers)
|
||||||
|
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
||||||
|
|
||||||
token_strings = [args.token_string] + [f"{args.token_string}{i+1}" for i in range(args.num_vectors_per_token - 1)]
|
|
||||||
token_ids_list = []
|
token_ids_list = []
|
||||||
token_embeds_list = []
|
token_embeds_list = []
|
||||||
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
|
for i, (tokenizer, text_encoder, init_token_ids) in enumerate(zip(tokenizers, text_encoders, init_token_ids_list)):
|
||||||
@@ -332,7 +350,7 @@ class TextualInversionTrainer:
|
|||||||
prompt_replacement = None
|
prompt_replacement = None
|
||||||
else:
|
else:
|
||||||
# サンプル生成用
|
# サンプル生成用
|
||||||
if args.num_vectors_per_token > 1:
|
if args.num_vectors_per_token > 1 and args.token_strings is None:
|
||||||
replace_to = " ".join(token_strings)
|
replace_to = " ".join(token_strings)
|
||||||
train_dataset_group.add_replacement(args.token_string, replace_to)
|
train_dataset_group.add_replacement(args.token_string, replace_to)
|
||||||
prompt_replacement = (args.token_string, replace_to)
|
prompt_replacement = (args.token_string, replace_to)
|
||||||
@@ -752,6 +770,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=None,
|
default=None,
|
||||||
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
help="token string used in training, must not exist in tokenizer / 学習時に使用されるトークン文字列、tokenizerに存在しない文字であること",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token_strings",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
nargs="*",
|
||||||
|
help="token strings used in training for multiple embedding / 複数のembeddingsの個別学習時に使用されるトークン文字列",
|
||||||
|
)
|
||||||
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
parser.add_argument("--init_word", type=str, default=None, help="words to initialize vector / ベクトルを初期化に使用する単語、複数可")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--use_object_template",
|
"--use_object_template",
|
||||||
|
|||||||
@@ -1,196 +0,0 @@
|
|||||||
# use Diffusers' pipeline to generate images
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import datetime
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import re
|
|
||||||
from einops import repeat
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
from tqdm import tqdm
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import CLIPTextModel, PreTrainedTokenizerFast
|
|
||||||
from diffusers.pipelines.wuerstchen.modeling_wuerstchen_prior import WuerstchenPrior
|
|
||||||
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
|
|
||||||
|
|
||||||
# from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS
|
|
||||||
from wuerstchen_train import EfficientNetEncoder
|
|
||||||
|
|
||||||
|
|
||||||
def generate(args):
|
|
||||||
dtype = torch.float32
|
|
||||||
if args.fp16:
|
|
||||||
dtype = torch.float16
|
|
||||||
elif args.bf16:
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
device = args.device
|
|
||||||
|
|
||||||
os.makedirs(args.outdir, exist_ok=True)
|
|
||||||
|
|
||||||
# load tokenizer
|
|
||||||
print("load tokenizer")
|
|
||||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
|
|
||||||
|
|
||||||
# load text encoder
|
|
||||||
print("load text encoder")
|
|
||||||
text_encoder = CLIPTextModel.from_pretrained(
|
|
||||||
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# load prior model
|
|
||||||
print("load prior model")
|
|
||||||
prior: WuerstchenPrior = WuerstchenPrior.from_pretrained(
|
|
||||||
args.pretrained_prior_model_name_or_path, subfolder="prior", torch_dtype=dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# Diffusers版のxformers使用フラグを設定する関数
|
|
||||||
def set_diffusers_xformers_flag(model, valid):
|
|
||||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
|
||||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
|
||||||
module.set_use_memory_efficient_attention_xformers(valid)
|
|
||||||
|
|
||||||
for child in module.children():
|
|
||||||
fn_recursive_set_mem_eff(child)
|
|
||||||
|
|
||||||
fn_recursive_set_mem_eff(model)
|
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
|
||||||
if args.diffusers_xformers:
|
|
||||||
print("Use xformers by Diffusers")
|
|
||||||
set_diffusers_xformers_flag(prior, True)
|
|
||||||
|
|
||||||
# load pipeline
|
|
||||||
print("load pipeline")
|
|
||||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
|
||||||
args.pretrained_decoder_model_name_or_path,
|
|
||||||
prior_prior=prior,
|
|
||||||
prior_text_encoder=text_encoder,
|
|
||||||
prior_tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
pipeline = pipeline.to(device, torch_dtype=dtype)
|
|
||||||
|
|
||||||
# generate image
|
|
||||||
while True:
|
|
||||||
width = args.w
|
|
||||||
height = args.h
|
|
||||||
seed = args.seed
|
|
||||||
negative_prompt = None
|
|
||||||
|
|
||||||
if args.interactive:
|
|
||||||
print("prompt:")
|
|
||||||
prompt = input()
|
|
||||||
if prompt == "":
|
|
||||||
break
|
|
||||||
|
|
||||||
# parse prompt
|
|
||||||
prompt_args = prompt.split(" --")
|
|
||||||
prompt = prompt_args[0]
|
|
||||||
|
|
||||||
for parg in prompt_args[1:]:
|
|
||||||
try:
|
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
width = int(m.group(1))
|
|
||||||
print(f"width: {width}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
|
||||||
if m:
|
|
||||||
height = int(m.group(1))
|
|
||||||
print(f"height: {height}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
|
|
||||||
if m: # seed
|
|
||||||
seed = int(m.group(1))
|
|
||||||
print(f"seed: {seed}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
|
||||||
if m: # negative prompt
|
|
||||||
negative_prompt = m.group(1)
|
|
||||||
print(f"negative prompt: {negative_prompt}")
|
|
||||||
continue
|
|
||||||
except ValueError as ex:
|
|
||||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
|
||||||
print(ex)
|
|
||||||
else:
|
|
||||||
prompt = args.prompt
|
|
||||||
negative_prompt = args.negative_prompt
|
|
||||||
|
|
||||||
if seed is None:
|
|
||||||
generator = None
|
|
||||||
else:
|
|
||||||
generator = torch.Generator(device=device).manual_seed(seed)
|
|
||||||
|
|
||||||
with torch.autocast(device):
|
|
||||||
image = pipeline(
|
|
||||||
prompt,
|
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
# prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
|
||||||
generator=generator,
|
|
||||||
width=width,
|
|
||||||
height=height,
|
|
||||||
).images[0]
|
|
||||||
|
|
||||||
# save image
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
|
||||||
image.save(os.path.join(args.outdir, f"image_{timestamp}.png"))
|
|
||||||
|
|
||||||
if not args.interactive:
|
|
||||||
break
|
|
||||||
|
|
||||||
print("Done!")
|
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
|
|
||||||
# train_util.add_sd_models_arguments(parser)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pretrained_prior_model_name_or_path",
|
|
||||||
type=str,
|
|
||||||
default="warp-ai/wuerstchen-prior",
|
|
||||||
required=False,
|
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pretrained_decoder_model_name_or_path",
|
|
||||||
type=str,
|
|
||||||
default="warp-ai/wuerstchen",
|
|
||||||
required=False,
|
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--prompt", type=str, default="A photo of a cat")
|
|
||||||
parser.add_argument("--negative_prompt", type=str, default="")
|
|
||||||
parser.add_argument("--outdir", type=str, default=".")
|
|
||||||
parser.add_argument("--w", type=int, default=1024)
|
|
||||||
parser.add_argument("--h", type=int, default=1024)
|
|
||||||
parser.add_argument("--interactive", action="store_true")
|
|
||||||
parser.add_argument("--fp16", action="store_true")
|
|
||||||
parser.add_argument("--bf16", action="store_true")
|
|
||||||
parser.add_argument("--device", type=str, default="cuda")
|
|
||||||
parser.add_argument("--seed", type=int, default=None)
|
|
||||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = setup_parser()
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
generate(args)
|
|
||||||
@@ -1,648 +0,0 @@
|
|||||||
# training with captions
|
|
||||||
# heavily based on https://github.com/kashif/diffusers
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import gc
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from multiprocessing import Value
|
|
||||||
from typing import List
|
|
||||||
import toml
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
import torch
|
|
||||||
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s
|
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
try:
|
|
||||||
import intel_extension_for_pytorch as ipex
|
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
|
||||||
from library.ipex import ipex_init
|
|
||||||
|
|
||||||
ipex_init()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
from accelerate.utils import set_seed
|
|
||||||
from transformers import CLIPTextModel, PreTrainedTokenizerFast
|
|
||||||
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
|
|
||||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
|
||||||
from diffusers.models.modeling_utils import ModelMixin
|
|
||||||
from diffusers.pipelines.wuerstchen.modeling_wuerstchen_prior import WuerstchenPrior
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
|
|
||||||
import library.train_util as train_util
|
|
||||||
import library.config_util as config_util
|
|
||||||
from library.config_util import (
|
|
||||||
ConfigSanitizer,
|
|
||||||
BlueprintGenerator,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EfficientNetEncoder(ModelMixin, ConfigMixin):
|
|
||||||
@register_to_config
|
|
||||||
def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if effnet == "efficientnet_v2_s":
|
|
||||||
self.backbone = efficientnet_v2_s(weights="DEFAULT").features
|
|
||||||
else:
|
|
||||||
self.backbone = efficientnet_v2_l(weights="DEFAULT").features
|
|
||||||
self.mapper = torch.nn.Sequential(
|
|
||||||
torch.nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
|
|
||||||
torch.nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.mapper(self.backbone(x))
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetWrapper(train_util.DatasetGroup):
|
|
||||||
r"""
|
|
||||||
Wrapper for datasets to be used with DataLoader.
|
|
||||||
add effnet_pixel_values and text_mask to dataset.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# なんかうまいことやればattributeをコピーしなくてもいい気がする
|
|
||||||
|
|
||||||
def __init__(self, dataset, tokenizer):
|
|
||||||
self.dataset = dataset
|
|
||||||
self.image_data = dataset.image_data
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.num_train_images = dataset.num_train_images
|
|
||||||
self.datasets = dataset.datasets
|
|
||||||
|
|
||||||
# images are already resized
|
|
||||||
self.effnet_transforms = transforms.Compose(
|
|
||||||
[
|
|
||||||
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
item = self.dataset[idx]
|
|
||||||
|
|
||||||
# create attention mask by input_ids
|
|
||||||
input_ids = item["input_ids"]
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
attention_mask[input_ids == self.tokenizer.pad_token_id] = 0
|
|
||||||
text_mask = attention_mask.bool()
|
|
||||||
item["text_mask"] = text_mask
|
|
||||||
|
|
||||||
# create effnet input
|
|
||||||
images = item["images"]
|
|
||||||
# effnet_pixel_values = [self.effnet_transforms(image) for image in images]
|
|
||||||
# effnet_pixel_values = torch.stack(effnet_pixel_values, dim=0)
|
|
||||||
effnet_pixel_values = self.effnet_transforms(((images) + 1.0) / 2.0)
|
|
||||||
effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format)
|
|
||||||
item["effnet_pixel_values"] = effnet_pixel_values
|
|
||||||
|
|
||||||
return item
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.dataset)
|
|
||||||
|
|
||||||
def add_replacement(self, str_from, str_to):
|
|
||||||
self.dataset.add_replacement(str_from, str_to)
|
|
||||||
|
|
||||||
def enable_XTI(self, *args, **kwargs):
|
|
||||||
self.dataset.enable_XTI(*args, **kwargs)
|
|
||||||
|
|
||||||
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
|
|
||||||
self.dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
|
|
||||||
|
|
||||||
def cache_text_encoder_outputs(
|
|
||||||
self, tokenizers, text_encoders, device, weight_dtype, cache_to_disk=False, is_main_process=True
|
|
||||||
):
|
|
||||||
self.dataset.cache_text_encoder_outputs(tokenizers, text_encoders, device, weight_dtype, cache_to_disk, is_main_process)
|
|
||||||
|
|
||||||
def set_caching_mode(self, caching_mode):
|
|
||||||
self.dataset.set_caching_mode(caching_mode)
|
|
||||||
|
|
||||||
def verify_bucket_reso_steps(self, min_steps: int):
|
|
||||||
self.dataset.verify_bucket_reso_steps(min_steps)
|
|
||||||
|
|
||||||
def is_latent_cacheable(self) -> bool:
|
|
||||||
return self.dataset.is_latent_cacheable()
|
|
||||||
|
|
||||||
def is_text_encoder_output_cacheable(self) -> bool:
|
|
||||||
return self.dataset.is_text_encoder_output_cacheable()
|
|
||||||
|
|
||||||
def set_current_epoch(self, epoch):
|
|
||||||
self.dataset.set_current_epoch(epoch)
|
|
||||||
|
|
||||||
def set_current_step(self, step):
|
|
||||||
self.dataset.set_current_step(step)
|
|
||||||
|
|
||||||
def set_max_train_steps(self, max_train_steps):
|
|
||||||
self.dataset.set_max_train_steps(max_train_steps)
|
|
||||||
|
|
||||||
def disable_token_padding(self):
|
|
||||||
self.dataset.disable_token_padding()
|
|
||||||
|
|
||||||
|
|
||||||
def get_hidden_states(args: argparse.Namespace, input_ids, text_mask, tokenizer, text_encoder, weight_dtype=None):
|
|
||||||
# with no_token_padding, the length is not max length, return result immediately
|
|
||||||
if input_ids.size()[-1] != tokenizer.model_max_length:
|
|
||||||
return text_encoder(input_ids, attention_mask=text_mask)[0]
|
|
||||||
|
|
||||||
# input_ids: b,n,77
|
|
||||||
b_size = input_ids.size()[0]
|
|
||||||
input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
|
||||||
text_mask = text_mask.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
|
|
||||||
|
|
||||||
if args.clip_skip is None:
|
|
||||||
encoder_hidden_states = text_encoder(input_ids)[0]
|
|
||||||
else:
|
|
||||||
enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
|
|
||||||
encoder_hidden_states = enc_out["hidden_states"][-args.clip_skip]
|
|
||||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
|
||||||
|
|
||||||
# bs*3, 77, 768 or 1024
|
|
||||||
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
|
||||||
|
|
||||||
if args.max_token_length is not None:
|
|
||||||
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
|
||||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
|
||||||
for i in range(1, args.max_token_length, tokenizer.model_max_length):
|
|
||||||
states_list.append(encoder_hidden_states[:, i : i + tokenizer.model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
|
||||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
|
||||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
|
||||||
|
|
||||||
if weight_dtype is not None:
|
|
||||||
# this is required for additional network training
|
|
||||||
encoder_hidden_states = encoder_hidden_states.to(weight_dtype)
|
|
||||||
|
|
||||||
return encoder_hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
|
||||||
# TODO: add checking for unsupported args
|
|
||||||
# TODO: cache image encoder outputs instead of latents
|
|
||||||
|
|
||||||
# train_util.verify_training_args(args)
|
|
||||||
train_util.prepare_dataset_args(args, True)
|
|
||||||
|
|
||||||
use_dreambooth_method = args.in_json is None
|
|
||||||
|
|
||||||
if args.seed is not None:
|
|
||||||
set_seed(args.seed) # 乱数系列を初期化する
|
|
||||||
|
|
||||||
print("prepare tokenizer")
|
|
||||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")
|
|
||||||
|
|
||||||
# データセットを準備する
|
|
||||||
if args.dataset_class is None:
|
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
|
|
||||||
if args.dataset_config is not None:
|
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
|
||||||
ignored = ["train_data_dir", "in_json"]
|
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
|
||||||
print(
|
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
|
||||||
", ".join(ignored)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if use_dreambooth_method:
|
|
||||||
print("Using DreamBooth method.")
|
|
||||||
user_config = {
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
|
|
||||||
args.train_data_dir, args.reg_data_dir
|
|
||||||
)
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
print("Training with captions.")
|
|
||||||
user_config = {
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"subsets": [
|
|
||||||
{
|
|
||||||
"image_dir": args.train_data_dir,
|
|
||||||
"metadata_file": args.in_json,
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
|
||||||
else:
|
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
|
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
|
||||||
current_step = Value("i", 0)
|
|
||||||
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
|
||||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
|
||||||
|
|
||||||
train_dataset_group.verify_bucket_reso_steps(32)
|
|
||||||
|
|
||||||
# wrap for wuestchen
|
|
||||||
train_dataset_group = DatasetWrapper(train_dataset_group, tokenizer)
|
|
||||||
|
|
||||||
if args.debug_dataset:
|
|
||||||
train_util.debug_dataset(train_dataset_group, True)
|
|
||||||
return
|
|
||||||
if len(train_dataset_group) == 0:
|
|
||||||
print(
|
|
||||||
"No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# acceleratorを準備する
|
|
||||||
print("prepare accelerator")
|
|
||||||
accelerator = train_util.prepare_accelerator(args)
|
|
||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
|
||||||
weight_dtype, _ = train_util.prepare_dtype(args)
|
|
||||||
|
|
||||||
# Load scheduler, effnet, tokenizer, clip_model
|
|
||||||
print("prepare scheduler, effnet, clip_model")
|
|
||||||
noise_scheduler = DDPMWuerstchenScheduler()
|
|
||||||
|
|
||||||
# TODO support explicit local caching for faster loading
|
|
||||||
pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
|
|
||||||
state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
|
|
||||||
image_encoder = EfficientNetEncoder()
|
|
||||||
image_encoder.load_state_dict(state_dict["effnet_state_dict"])
|
|
||||||
image_encoder.eval()
|
|
||||||
|
|
||||||
text_encoder = CLIPTextModel.from_pretrained(
|
|
||||||
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# Freeze text_encoder and image_encoder
|
|
||||||
text_encoder.requires_grad_(False)
|
|
||||||
image_encoder.requires_grad_(False)
|
|
||||||
|
|
||||||
# load prior model
|
|
||||||
prior: WuerstchenPrior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
|
||||||
|
|
||||||
# EMA is not supported yet
|
|
||||||
|
|
||||||
# Diffusers版のxformers使用フラグを設定する関数
|
|
||||||
def set_diffusers_xformers_flag(model, valid):
|
|
||||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
|
||||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
|
||||||
module.set_use_memory_efficient_attention_xformers(valid)
|
|
||||||
|
|
||||||
for child in module.children():
|
|
||||||
fn_recursive_set_mem_eff(child)
|
|
||||||
|
|
||||||
fn_recursive_set_mem_eff(model)
|
|
||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
|
||||||
if args.diffusers_xformers:
|
|
||||||
accelerator.print("Use xformers by Diffusers")
|
|
||||||
set_diffusers_xformers_flag(prior, True)
|
|
||||||
|
|
||||||
# 学習を準備する
|
|
||||||
|
|
||||||
# 学習を準備する:モデルを適切な状態にする
|
|
||||||
training_models = []
|
|
||||||
if args.gradient_checkpointing:
|
|
||||||
# prior.enable_gradient_checkpointing()
|
|
||||||
print("*" * 80)
|
|
||||||
print("*** Prior model does not support gradient checkpointing. ***")
|
|
||||||
print("*" * 80)
|
|
||||||
training_models.append(prior)
|
|
||||||
|
|
||||||
text_encoder.requires_grad_(False)
|
|
||||||
text_encoder.eval()
|
|
||||||
|
|
||||||
for m in training_models:
|
|
||||||
m.requires_grad_(True)
|
|
||||||
|
|
||||||
params = []
|
|
||||||
for m in training_models:
|
|
||||||
params.extend(m.parameters())
|
|
||||||
params_to_optimize = params
|
|
||||||
|
|
||||||
# calculate number of trainable parameters
|
|
||||||
n_params = 0
|
|
||||||
for p in params:
|
|
||||||
n_params += p.numel()
|
|
||||||
|
|
||||||
accelerator.print(f"number of models: {len(training_models)}")
|
|
||||||
accelerator.print(f"number of trainable parameters: {n_params}")
|
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
|
||||||
accelerator.print("prepare optimizer, data loader etc.")
|
|
||||||
_, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize)
|
|
||||||
|
|
||||||
# dataloaderを準備する
|
|
||||||
|
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
|
||||||
train_dataset_group,
|
|
||||||
batch_size=1,
|
|
||||||
shuffle=True,
|
|
||||||
collate_fn=collator,
|
|
||||||
num_workers=n_workers,
|
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 学習ステップ数を計算する
|
|
||||||
if args.max_train_epochs is not None:
|
|
||||||
args.max_train_steps = args.max_train_epochs * math.ceil(
|
|
||||||
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
|
||||||
)
|
|
||||||
accelerator.print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
|
||||||
|
|
||||||
# データセット側にも学習ステップを送信
|
|
||||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
|
||||||
|
|
||||||
# lr schedulerを用意する
|
|
||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする
|
|
||||||
if args.full_fp16:
|
|
||||||
assert (
|
|
||||||
args.mixed_precision == "fp16"
|
|
||||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
|
||||||
accelerator.print("enable full fp16 training.")
|
|
||||||
prior.to(weight_dtype)
|
|
||||||
text_encoder.to(weight_dtype)
|
|
||||||
elif args.full_bf16:
|
|
||||||
assert (
|
|
||||||
args.mixed_precision == "bf16"
|
|
||||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
|
||||||
accelerator.print("enable full bf16 training.")
|
|
||||||
prior.to(weight_dtype)
|
|
||||||
text_encoder.to(weight_dtype)
|
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
|
||||||
prior, image_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
|
||||||
prior, image_encoder, optimizer, train_dataloader, lr_scheduler
|
|
||||||
)
|
|
||||||
(prior, image_encoder) = train_util.transform_models_if_DDP([prior, image_encoder])
|
|
||||||
|
|
||||||
text_encoder.to(weight_dtype)
|
|
||||||
text_encoder.to(accelerator.device)
|
|
||||||
image_encoder.to(weight_dtype)
|
|
||||||
image_encoder.to(accelerator.device)
|
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
|
||||||
if args.full_fp16:
|
|
||||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
|
||||||
|
|
||||||
# resumeする
|
|
||||||
train_util.resume_from_local_or_hf_if_specified(accelerator, args)
|
|
||||||
|
|
||||||
# epoch数を計算する
|
|
||||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
|
||||||
num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
|
||||||
if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
|
|
||||||
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
|
|
||||||
|
|
||||||
# 学習する
|
|
||||||
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
|
||||||
accelerator.print("running training / 学習開始")
|
|
||||||
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
|
|
||||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
|
||||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
|
||||||
accelerator.print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
|
|
||||||
# accelerator.print(
|
|
||||||
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
|
|
||||||
# )
|
|
||||||
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
|
|
||||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
|
||||||
|
|
||||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
|
||||||
global_step = 0
|
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
init_kwargs = {}
|
|
||||||
if args.log_tracker_config is not None:
|
|
||||||
init_kwargs = toml.load(args.log_tracker_config)
|
|
||||||
accelerator.init_trackers(
|
|
||||||
"wuerstchen_finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# workaround for DDPMWuerstchenScheduler
|
|
||||||
def add_noise(
|
|
||||||
scheduler: DDPMWuerstchenScheduler,
|
|
||||||
original_samples: torch.FloatTensor,
|
|
||||||
noise: torch.FloatTensor,
|
|
||||||
timesteps: torch.IntTensor,
|
|
||||||
) -> torch.FloatTensor:
|
|
||||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
|
||||||
alphas_cumprod_timesteps = scheduler._alpha_cumprod(timesteps, original_samples.device)
|
|
||||||
sqrt_alpha_prod = alphas_cumprod_timesteps**0.5
|
|
||||||
|
|
||||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
|
||||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
|
||||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
|
||||||
|
|
||||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod_timesteps) ** 0.5
|
|
||||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
|
||||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
|
||||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
|
||||||
|
|
||||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
|
||||||
return noisy_samples
|
|
||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
|
||||||
current_epoch.value = epoch + 1
|
|
||||||
|
|
||||||
for m in training_models:
|
|
||||||
m.train()
|
|
||||||
|
|
||||||
loss_total = 0
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
|
||||||
current_step.value = global_step
|
|
||||||
with accelerator.accumulate(training_models[0]): # 複数モデルに対応していない模様だがとりあえずこうしておく
|
|
||||||
input_ids = batch["input_ids"]
|
|
||||||
text_mask = batch["text_mask"]
|
|
||||||
effnet_pixel_values = batch["effnet_pixel_values"]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
input_ids = input_ids.to(accelerator.device)
|
|
||||||
text_mask = text_mask.to(accelerator.device)
|
|
||||||
prompt_embeds = get_hidden_states(
|
|
||||||
args, input_ids, text_mask, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
image_embeds = image_encoder(effnet_pixel_values)
|
|
||||||
image_embeds = image_embeds.add(1.0).div(42.0) # scale
|
|
||||||
|
|
||||||
# Sample noise that we'll add to the image_embeds
|
|
||||||
noise = torch.randn_like(image_embeds)
|
|
||||||
bsz = image_embeds.shape[0]
|
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
|
||||||
# TODO support mul/add/clump
|
|
||||||
timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
# add noise to latent: This is same to Diffuzz.diffuse in diffuzz.py
|
|
||||||
# noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
|
|
||||||
noisy_latents = add_noise(noise_scheduler, image_embeds, noise, timesteps)
|
|
||||||
|
|
||||||
# Predict the noise residual
|
|
||||||
with accelerator.autocast():
|
|
||||||
noise_pred = prior(noisy_latents, timesteps, prompt_embeds)
|
|
||||||
|
|
||||||
target = noise
|
|
||||||
|
|
||||||
# TODO add consistency loss
|
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
|
||||||
|
|
||||||
accelerator.backward(loss)
|
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
|
||||||
params_to_clip = []
|
|
||||||
for m in training_models:
|
|
||||||
params_to_clip.extend(m.parameters())
|
|
||||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
|
||||||
|
|
||||||
optimizer.step()
|
|
||||||
lr_scheduler.step()
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
|
||||||
|
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
|
||||||
if accelerator.sync_gradients:
|
|
||||||
progress_bar.update(1)
|
|
||||||
global_step += 1
|
|
||||||
|
|
||||||
# TODO ここでサンプルを生成する
|
|
||||||
# sample_images(
|
|
||||||
# accelerator,
|
|
||||||
# args,
|
|
||||||
# None,
|
|
||||||
# global_step,
|
|
||||||
# accelerator.device,
|
|
||||||
# vae,
|
|
||||||
# [tokenizer1, tokenizer2],
|
|
||||||
# [text_encoder, text_encoder2],
|
|
||||||
# prior,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# 指定ステップごとにモデルを保存
|
|
||||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
# TODO simplify to save prior only
|
|
||||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
|
||||||
args.pretrained_decoder_model_name_or_path,
|
|
||||||
prior_prior=accelerator.unwrap_model(prior),
|
|
||||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
|
||||||
prior_tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
ckpt_name = train_util.get_step_ckpt_name(args, "", global_step)
|
|
||||||
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, ckpt_name))
|
|
||||||
|
|
||||||
# TODO remove older saved models
|
|
||||||
|
|
||||||
current_loss = loss.detach().item() # 平均なのでbatch sizeは関係ないはず
|
|
||||||
if args.logging_dir is not None:
|
|
||||||
logs = {"loss": current_loss}
|
|
||||||
logs["lr"] = float(lr_scheduler.get_last_lr()[0])
|
|
||||||
if (
|
|
||||||
args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower()
|
|
||||||
): # tracking d*lr value
|
|
||||||
logs["lr/d*lr"] = (
|
|
||||||
lr_scheduler.optimizers[0].param_groups[0]["d"] * lr_scheduler.optimizers[0].param_groups[0]["lr"]
|
|
||||||
)
|
|
||||||
accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
# TODO moving averageにする
|
|
||||||
loss_total += current_loss
|
|
||||||
avr_loss = loss_total / (step + 1)
|
|
||||||
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
|
||||||
progress_bar.set_postfix(**logs)
|
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
|
||||||
break
|
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
|
||||||
logs = {"loss/epoch": loss_total / len(train_dataloader)}
|
|
||||||
accelerator.log(logs, step=epoch + 1)
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
|
||||||
|
|
||||||
if args.save_every_n_epochs is not None:
|
|
||||||
if accelerator.is_main_process:
|
|
||||||
epoch_no = epoch + 1
|
|
||||||
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
|
||||||
if saving:
|
|
||||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
|
||||||
args.pretrained_decoder_model_name_or_path,
|
|
||||||
prior_prior=accelerator.unwrap_model(prior),
|
|
||||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
|
||||||
prior_tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
ckpt_name = train_util.get_epoch_ckpt_name(args, "", epoch)
|
|
||||||
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, ckpt_name))
|
|
||||||
|
|
||||||
# TODO remove older saved models
|
|
||||||
|
|
||||||
# TODO ここでサンプルを生成する
|
|
||||||
|
|
||||||
is_main_process = accelerator.is_main_process
|
|
||||||
|
|
||||||
accelerator.end_training()
|
|
||||||
|
|
||||||
if args.save_state: # and is_main_process:
|
|
||||||
train_util.save_state_on_train_end(args, accelerator)
|
|
||||||
|
|
||||||
# del accelerator # この後メモリを使うのでこれは消す
|
|
||||||
|
|
||||||
if is_main_process:
|
|
||||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
|
||||||
args.pretrained_decoder_model_name_or_path,
|
|
||||||
prior_prior=accelerator.unwrap_model(prior),
|
|
||||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
|
||||||
prior_tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
ckpt_name = train_util.get_last_ckpt_name(args, "")
|
|
||||||
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, ckpt_name))
|
|
||||||
print("model saved.")
|
|
||||||
|
|
||||||
|
|
||||||
def setup_parser() -> argparse.ArgumentParser:
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
|
|
||||||
# train_util.add_sd_models_arguments(parser)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pretrained_prior_model_name_or_path",
|
|
||||||
type=str,
|
|
||||||
default="warp-ai/wuerstchen-prior",
|
|
||||||
required=False,
|
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--pretrained_decoder_model_name_or_path",
|
|
||||||
type=str,
|
|
||||||
default="warp-ai/wuerstchen",
|
|
||||||
required=False,
|
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
|
||||||
)
|
|
||||||
|
|
||||||
train_util.add_dataset_arguments(parser, True, True, True)
|
|
||||||
train_util.add_training_arguments(parser, False)
|
|
||||||
# train_util.add_sd_saving_arguments(parser)
|
|
||||||
train_util.add_optimizer_arguments(parser)
|
|
||||||
config_util.add_config_arguments(parser)
|
|
||||||
|
|
||||||
# TODO add assertion for SD related arguments
|
|
||||||
|
|
||||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
|
||||||
return parser
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = setup_parser()
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
args = train_util.read_config_from_file(args, parser)
|
|
||||||
|
|
||||||
train(args)
|
|
||||||
Reference in New Issue
Block a user