Merge pull request #553 from kohya-ss/dev

no caption warning, network merging before training
This commit is contained in:
Kohya S
2023-05-31 21:04:50 +09:00
committed by GitHub
4 changed files with 109 additions and 32 deletions

View File

@@ -140,6 +140,24 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### 31 May 2023, 2023/05/31
- Show warning when image caption file does not exist during training. [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) Thanks to TingTingin!
- Warning is also displayed when using class+identifier dataset. Please ignore if it is intended.
- `train_network.py` now supports merging network weights before training. [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) Thanks to u-haru!
- `--base_weights` option specifies LoRA or other model files (multiple files are allowed) to merge.
- `--base_weights_multiplier` option specifies multiplier of the weights to merge (multiple values are allowed). If omitted or less than `base_weights`, 1.0 is used.
- This is useful for incremental learning. See PR for details.
- Show warning and continue training when uploading to HuggingFace fails.
- 学習時に画像のキャプションファイルが存在しない場合、警告が表示されるようになりました。 [PR #533](https://github.com/kohya-ss/sd-scripts/pull/533) TingTingin氏に感謝します。
- class+identifier方式のデータセットを利用している場合も警告が表示されます。意図している通りの場合は無視してください。
- `train_network.py` に学習前にモデルにnetworkの重みをマージする機能が追加されました。 [PR #542](https://github.com/kohya-ss/sd-scripts/pull/542) u-haru氏に感謝します。
- `--base_weights` オプションでLoRA等のモデルファイル複数可を指定すると、それらの重みをマージします。
- `--base_weights_multiplier` オプションでマージする重みの倍率(複数可)を指定できます。省略時または`base_weights`よりも数が少ない場合は1.0になります。
- 差分追加学習などにご利用ください。詳細はPRをご覧ください。
- HuggingFaceへのアップロードに失敗した場合、警告を表示しそのまま学習を続行するよう変更しました。
### 25 May 2023, 2023/05/25
- [D-Adaptation v3.0](https://github.com/facebookresearch/dadaptation) is now supported. [PR #530](https://github.com/kohya-ss/sd-scripts/pull/530) Thanks to sdbds!

View File

@@ -6,9 +6,7 @@ import os
from library.utils import fire_in_thread
def exists_repo(
repo_id: str, repo_type: str, revision: str = "main", token: str = None
):
def exists_repo(repo_id: str, repo_type: str, revision: str = "main", token: str = None):
api = HfApi(
token=token,
)
@@ -32,27 +30,35 @@ def upload(
private = args.huggingface_repo_visibility is None or args.huggingface_repo_visibility != "public"
api = HfApi(token=token)
if not exists_repo(repo_id=repo_id, repo_type=repo_type, token=token):
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
try:
api.create_repo(repo_id=repo_id, repo_type=repo_type, private=private)
except Exception as e: # とりあえずRepositoryNotFoundErrorは確認したが他にあると困るので
print("===========================================")
print(f"failed to create HuggingFace repo / HuggingFaceのリポジトリの作成に失敗しました : {e}")
print("===========================================")
is_folder = (type(src) == str and os.path.isdir(src)) or (
isinstance(src, Path) and src.is_dir()
)
is_folder = (type(src) == str and os.path.isdir(src)) or (isinstance(src, Path) and src.is_dir())
def uploader():
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)
try:
if is_folder:
api.upload_folder(
repo_id=repo_id,
repo_type=repo_type,
folder_path=src,
path_in_repo=path_in_repo,
)
else:
api.upload_file(
repo_id=repo_id,
repo_type=repo_type,
path_or_fileobj=src,
path_in_repo=path_in_repo,
)
except Exception as e: # RuntimeErrorを確認済みだが他にあると困るので
print("===========================================")
print(f"failed to upload to HuggingFace / HuggingFaceへのアップロードに失敗しました : {e}")
print("===========================================")
if args.async_upload and not force_sync_upload:
fire_in_thread(uploader)
@@ -71,7 +77,5 @@ def list_dir(
token=token,
)
repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type=repo_type)
file_list = [
file for file in repo_info.siblings if file.rfilename.startswith(subfolder)
]
file_list = [file for file in repo_info.siblings if file.rfilename.startswith(subfolder)]
return file_list

View File

@@ -348,6 +348,8 @@ class DreamBoothSubset(BaseSubset):
self.is_reg = is_reg
self.class_tokens = class_tokens
self.caption_extension = caption_extension
if self.caption_extension and not self.caption_extension.startswith("."):
self.caption_extension = "." + self.caption_extension
def __eq__(self, other) -> bool:
if not isinstance(other, DreamBoothSubset):
@@ -1081,16 +1083,37 @@ class DreamBoothDataset(BaseDataset):
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
missing_captions = []
for img_path in img_paths:
cap_for_img = read_caption(img_path, subset.caption_extension)
if cap_for_img is None and subset.class_tokens is None:
print(f"neither caption file nor class tokens are found. use empty caption for {img_path}")
print(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
missing_captions.append(img_path)
else:
captions.append(subset.class_tokens if cap_for_img is None else cap_for_img)
if cap_for_img is None:
captions.append(subset.class_tokens)
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
if missing_captions:
number_of_missing_captions = len(missing_captions)
number_of_missing_captions_to_show = 5
remaining_missing_captions = number_of_missing_captions - number_of_missing_captions_to_show
print(
f"No caption file found for {number_of_missing_captions} images. Training will continue without captions for these images. If class token exists, it will be used. / {number_of_missing_captions}枚の画像にキャプションファイルが見つかりませんでした。これらの画像についてはキャプションなしで学習を続行します。class tokenが存在する場合はそれを使います。"
)
for i, missing_caption in enumerate(missing_captions):
if i >= number_of_missing_captions_to_show:
print(missing_caption + f"... and {remaining_missing_captions} more")
break
print(missing_caption)
return img_paths, captions
print("prepare images.")

View File

@@ -148,6 +148,30 @@ def train(args):
# モデルに xformers とか memory efficient attention を組み込む
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
# 差分追加学習のためにモデルを読み込む
import sys
sys.path.append(os.path.dirname(__file__))
print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
if args.base_weights is not None:
# base_weights が指定されている場合は、指定された重みを読み込みマージする
for i, weight_path in enumerate(args.base_weights):
if args.base_weights_multiplier is None or len(args.base_weights_multiplier) <= i:
multiplier = 1.0
else:
multiplier = args.base_weights_multiplier[i]
print(f"merging module: {weight_path} with multiplier {multiplier}")
module, weights_sd = network_module.create_network_from_weights(
multiplier, weight_path, vae, text_encoder, unet, for_inference=True
)
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
print(f"all weights merged: {', '.join(args.base_weights)}")
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
@@ -163,12 +187,6 @@ def train(args):
accelerator.wait_for_everyone()
# prepare network
import sys
sys.path.append(os.path.dirname(__file__))
print("import network module:", args.network_module)
network_module = importlib.import_module(args.network_module)
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
@@ -770,6 +788,20 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="automatically determine dim (rank) from network_weights / dim (rank)をnetwork_weightsで指定した重みから自動で決定する",
)
parser.add_argument(
"--base_weights",
type=str,
default=None,
nargs="*",
help="network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みファイル",
)
parser.add_argument(
"--base_weights_multiplier",
type=float,
default=None,
nargs="*",
help="multiplier for network weights to merge into the model before training / 学習前にあらかじめモデルにマージするnetworkの重みの倍率",
)
return parser