mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
Merge pull request #553 from kohya-ss/dev
no caption warning, network merging before training
This commit is contained in:
18
README.md
18
README.md
@@ -140,6 +140,24 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
|
||||
|
||||
## Change History
|
||||
|
||||
### 31 May 2023, 2023/05/31
|
||||
|
||||
- Show a warning when an image's caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin!
|
||||
- The warning is also displayed when using a class+identifier dataset. Please ignore it if this is intended.
|
||||
- `train_network.py` now supports merging network weights before training. [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) Thanks to u-haru!
|
||||
- `--base_weights` option specifies LoRA or other model files (multiple files are allowed) to merge.
|
||||
- `--base_weights_multiplier` option specifies the multiplier for each of the weights to merge (multiple values are allowed). If it is omitted, or if fewer values are given than there are `base_weights` files, 1.0 is used for the files without a multiplier.
|
||||
- This is useful for incremental learning. See PR for details.
|
||||
- Show a warning and continue training when uploading to HuggingFace fails.
|
||||
|
||||
- 学習時に画像のキャプションファイルが存在しない場合、警告が表示されるようになりました。 [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) TingTingin氏に感謝します。
|
||||
- class+identifier方式のデータセットを利用している場合も警告が表示されます。意図している通りの場合は無視してください。
|
||||
- `train_network.py` に学習前にモデルにnetworkの重みをマージする機能が追加されました。 [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) u-haru氏に感謝します。
|
||||
- `--base_weights` オプションでLoRA等のモデルファイル(複数可)を指定すると、それらの重みをマージします。
|
||||
- `--base_weights_multiplier` オプションでマージする重みの倍率(複数可)を指定できます。省略時または`base_weights`よりも数が少ない場合は1.0になります。
|
||||
- 差分追加学習などにご利用ください。詳細はPRをご覧ください。
|
||||
- HuggingFaceへのアップロードに失敗した場合、警告を表示しそのまま学習を続行するよう変更しました。
|
||||
|
||||
### 25 May 2023, 2023/05/25
|
||||
|
||||
- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation) is now supported. [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) Thanks to sdbds!
|
||||
|
||||
@@ -6,9 +6,7 @@ import os
|
||||
from library.utils import fire_in_thread
|
||||
|
||||
|
||||
def exists_repo(
|
||||
repo_id: str, repo_type: str, revision: str = "main", token: str = None
|
||||
):
|
||||
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
|
||||
api = HfApi(
|
||||
token=token,
|
||||
)
|
||||
@@ -32,27 +30,35 @@ def upload(
|
||||
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
|
||||
api = HfApi(token=token)
|
||||
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
|
||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||
try:
|
||||
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
|
||||
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
|
||||
print("===========================================")
|
||||
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
|
||||
print("===========================================")
|
||||
|
||||
is_folder = (type(src) == str and os.path.isdir(src)) or (
|
||||
isinstance(src, Path) and src.is_dir()
|
||||
)
|
||||
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
|
||||
|
||||
def uploader():
|
||||
if is_folder:
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
folder_path=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
else:
|
||||
api.upload_file(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
path_or_fileobj=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
try:
|
||||
if is_folder:
|
||||
api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
folder_path=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
else:
|
||||
api.upload_file(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
path_or_fileobj=src,
|
||||
path_in_repo=path_in_repo,
|
||||
)
|
||||
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
|
||||
print("===========================================")
|
||||
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
|
||||
print("===========================================")
|
||||
|
||||
if args.async_upload and not force_sync_upload:
|
||||
fire_in_thread(uploader)
|
||||
@@ -71,7 +77,5 @@ def list_dir(
|
||||
token=token,
|
||||
)
|
||||
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
|
||||
file_list = [
|
||||
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
|
||||
]
|
||||
file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
|
||||
return file_list
|
||||
|
||||
@@ -348,6 +348,8 @@ class DreamBoothSubset(BaseSubset):
|
||||
self.is_reg = is_reg
|
||||
self.class_tokens = class_tokens
|
||||
self.caption_extension = caption_extension
|
||||
if self.caption_extension and not self.caption_extension.startswith("."):
|
||||
self.caption_extension = "." + self.caption_extension
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, DreamBoothSubset):
|
||||
@@ -1081,16 +1083,37 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
missing_captions = []
|
||||
for img_path in img_paths:
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
|
||||
print(
|
||||
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
||||
)
|
||||
captions.append("")
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
|
||||
if cap_for_img is None:
|
||||
captions.append(subset.class_tokens)
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
if missing_captions:
|
||||
number_of_missing_captions = len(missing_captions)
|
||||
number_of_missing_captions_to_show = 5
|
||||
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
|
||||
|
||||
print(
|
||||
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
|
||||
)
|
||||
for i, missing_caption in enumerate(missing_captions):
|
||||
if i >= number_of_missing_captions_to_show:
|
||||
print(missing_caption + f"... and {remaining_missing_captions} more")
|
||||
break
|
||||
print(missing_caption)
|
||||
return img_paths, captions
|
||||
|
||||
print("prepare images.")
|
||||
|
||||
@@ -148,6 +148,30 @@ def train(args):
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
# 差分追加学習のためにモデルを読み込む
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
if args.base_weights is not None:
|
||||
# base_weights が指定されている場合は、指定された重みを読み込みマージする
|
||||
for i, weight_path in enumerate(args.base_weights):
|
||||
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
|
||||
multiplier = 1.0
|
||||
else:
|
||||
multiplier = args.base_weights_multiplier[i]
|
||||
|
||||
print(f"merging module: {weight_path} with multiplier {multiplier}")
|
||||
|
||||
module, weights_sd = network_module.create_network_from_weights(
|
||||
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
|
||||
)
|
||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||
|
||||
print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -163,12 +187,6 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# prepare network
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
print("import network module:", args.network_module)
|
||||
network_module = importlib.import_module(args.network_module)
|
||||
|
||||
net_kwargs = {}
|
||||
if args.network_args is not None:
|
||||
for net_arg in args.network_args:
|
||||
@@ -770,6 +788,20 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_weights",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base_weights_multiplier",
|
||||
type=float,
|
||||
default=None,
|
||||
nargs="*",
|
||||
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user