mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
feat: update fine tuning dataset
This commit is contained in:
@@ -12,17 +12,7 @@ import pathlib
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union
|
||||
)
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
||||
import glob
|
||||
import math
|
||||
@@ -69,7 +59,7 @@ from diffusers import (
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
AutoencoderKL,
|
||||
)
|
||||
from library import custom_train_functions, sd3_utils
|
||||
from library import custom_train_functions, dataset_metadata_utils, sd3_utils
|
||||
from library.original_unet import UNet2DConditionModel
|
||||
from huggingface_hub import hf_hub_download
|
||||
import numpy as np
|
||||
@@ -1071,8 +1061,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
# check disk cache exists and size of latents
|
||||
if caching_strategy.cache_to_disk:
|
||||
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
|
||||
info.latents_cache_path = caching_strategy.get_latents_cache_path(info.absolute_path, info.image_size)
|
||||
info.latents_cache_path = caching_strategy.get_latents_cache_path_from_info(info)
|
||||
|
||||
# if the modulo of num_processes is not equal to process_index, skip caching
|
||||
# this makes each process cache different latents
|
||||
@@ -1389,6 +1378,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
input_ids = None
|
||||
caption = ""
|
||||
|
||||
caption_index = None
|
||||
if image_info.text_encoder_outputs is not None:
|
||||
# cached on memory
|
||||
text_encoder_outputs = image_info.text_encoder_outputs
|
||||
@@ -1398,11 +1388,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
text_encoder_outputs = random.choices(text_encoder_outputs, weights=image_info.caption_weights)[0]
|
||||
elif image_info.text_encoder_outputs_cache_path is not None:
|
||||
# on disk
|
||||
index = 0
|
||||
if len(image_info.captions) > 1:
|
||||
index = random.choices(range(len(image_info.captions), weights=image_info.caption_weights))[0]
|
||||
# caption_weights may be None
|
||||
# `weights` belongs to random.choices(), not range(): range() accepts no keyword arguments
caption_index = random.choices(range(len(image_info.captions)), weights=image_info.caption_weights)[0]
|
||||
else:
|
||||
caption_index = 0
|
||||
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_from_disk(
|
||||
image_info.text_encoder_outputs_cache_path, index
|
||||
image_info.text_encoder_outputs_cache_path, caption_index
|
||||
)
|
||||
else:
|
||||
tokenization_required = True
|
||||
@@ -1412,8 +1404,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
caption = ""
|
||||
tags = None # None if no tags in dataset metadata or Dreambooth method is used
|
||||
if image_info.captions is not None and len(image_info.captions) > 0:
|
||||
# caption_weights may be None
|
||||
caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0]
|
||||
if caption_index is not None: # partially cached
|
||||
caption = image_info.captions[caption_index]
|
||||
else:
|
||||
caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0]
|
||||
if image_info.list_of_tags is not None and len(image_info.list_of_tags) > 0:
|
||||
tags = random.choices(image_info.list_of_tags, weights=image_info.tags_weights)[0]
|
||||
|
||||
@@ -1582,7 +1576,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
def read_caption(img_path, caption_extension, enable_wildcard):
|
||||
def read_caption(img_path, caption_extension) -> Optional[list[str]]:
|
||||
# captionの候補ファイル名を作る
|
||||
base_name = os.path.splitext(img_path)[0]
|
||||
base_name_face_det = base_name
|
||||
@@ -1591,7 +1585,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
base_name_face_det = "_".join(tokens[:-4])
|
||||
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
|
||||
|
||||
caption = None
|
||||
captions = None
|
||||
for cap_path in cap_paths:
|
||||
if os.path.isfile(cap_path):
|
||||
with open(cap_path, "rt", encoding="utf-8") as f:
|
||||
@@ -1601,14 +1595,12 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
|
||||
raise e
|
||||
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
|
||||
if enable_wildcard:
|
||||
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
|
||||
else:
|
||||
caption = lines[0].strip()
|
||||
captions = [line.strip() for line in lines]
|
||||
captions = [cap for cap in captions if cap != ""] # remove empty lines
|
||||
break
|
||||
return caption
|
||||
return captions
|
||||
|
||||
def load_dreambooth_dir(subset: DreamBoothSubset):
|
||||
def load_dreambooth_dir(subset: DreamBoothSubset) -> Tuple[list[str], list[list[str]], list[list[int]]]:
|
||||
if not os.path.isdir(subset.image_dir):
|
||||
logger.warning(f"not directory: {subset.image_dir}")
|
||||
return [], [], []
|
||||
@@ -1627,7 +1619,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
use_cached_info_for_subset = False
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
# json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
|
||||
# json: {`img_path`:{"caption": ["caption...", ...] , "resolution": [width, height]}, ...}
|
||||
with open(info_cache_file, "r", encoding="utf-8") as f:
|
||||
metas = json.load(f)
|
||||
img_paths = list(metas.keys())
|
||||
@@ -1676,28 +1668,31 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
if use_cached_info_for_subset:
|
||||
captions = [meta["caption"] for meta in metas.values()]
|
||||
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
|
||||
list_of_captions = [meta["caption"] for meta in metas.values()] # list of captions for each image
|
||||
missing_captions = [
|
||||
img_path for img_path, caption in zip(img_paths, list_of_captions) if caption is None or caption == ""
|
||||
]
|
||||
else:
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
list_of_captions: list[list[str]] = []
|
||||
missing_captions = []
|
||||
for img_path in tqdm(img_paths, desc="read caption"):
|
||||
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
|
||||
if cap_for_img is None and subset.class_tokens is None:
|
||||
caps_for_img = read_caption(img_path, subset.caption_extension)
|
||||
if caps_for_img is None and subset.class_tokens is None:
|
||||
logger.warning(
|
||||
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
|
||||
)
|
||||
captions.append("")
|
||||
list_of_captions.append([""])
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
if cap_for_img is None:
|
||||
captions.append(subset.class_tokens)
|
||||
if caps_for_img is None:
|
||||
list_of_captions.append([subset.class_tokens])
|
||||
missing_captions.append(img_path)
|
||||
else:
|
||||
captions.append(cap_for_img)
|
||||
list_of_captions.append(caps_for_img)
|
||||
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
for captions in list_of_captions:
|
||||
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
|
||||
|
||||
if missing_captions:
|
||||
number_of_missing_captions = len(missing_captions)
|
||||
@@ -1717,14 +1712,14 @@ class DreamBoothDataset(BaseDataset):
|
||||
logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
|
||||
sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
|
||||
matas = {}
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
matas[img_path] = {"caption": caption, "resolution": list(size)}
|
||||
for img_path, captions, size in zip(img_paths, list_of_captions, sizes):
|
||||
matas[img_path] = {"caption": captions, "resolution": list(size)}
|
||||
with open(info_cache_file, "w", encoding="utf-8") as f:
|
||||
json.dump(matas, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")
|
||||
|
||||
# if sizes are not set, image size will be read in make_buckets
|
||||
return img_paths, captions, sizes
|
||||
return img_paths, list_of_captions, sizes
|
||||
|
||||
logger.info("prepare images.")
|
||||
num_train_images = 0
|
||||
@@ -1743,7 +1738,7 @@ class DreamBoothDataset(BaseDataset):
|
||||
)
|
||||
continue
|
||||
|
||||
img_paths, captions, sizes = load_dreambooth_dir(subset)
|
||||
img_paths, list_of_captions, sizes = load_dreambooth_dir(subset)
|
||||
if len(img_paths) < 1:
|
||||
logger.warning(
|
||||
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
|
||||
@@ -1755,9 +1750,10 @@ class DreamBoothDataset(BaseDataset):
|
||||
else:
|
||||
num_train_images += subset.num_repeats * len(img_paths)
|
||||
|
||||
for img_path, caption, size in zip(img_paths, captions, sizes):
|
||||
captions = caption.split("\n") # empty line is allowed
|
||||
info = ImageInfo(img_path, subset.num_repeats, captions, subset.is_reg, img_path)
|
||||
for img_path, captions, size in zip(img_paths, list_of_captions, sizes):
|
||||
# NOTE: captions of DreamBoothDataset are treated as tags, because shuffle, drop, etc. are applied to them.
|
||||
info = ImageInfo(img_path, subset.num_repeats, subset.is_reg, img_path)
|
||||
info.list_of_tags = captions
|
||||
if size is not None:
|
||||
info.image_size = size
|
||||
if subset.is_reg:
|
||||
@@ -1817,6 +1813,9 @@ class FineTuningDataset(BaseDataset):
|
||||
self.num_train_images = 0
|
||||
self.num_reg_images = 0
|
||||
|
||||
latent_caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
|
||||
info_for_image_in_archive_is_shown = False
|
||||
for subset in subsets:
|
||||
if subset.num_repeats < 1:
|
||||
logger.warning(
|
||||
@@ -1831,195 +1830,76 @@ class FineTuningDataset(BaseDataset):
|
||||
continue
|
||||
|
||||
# メタデータを読み込む
|
||||
if os.path.exists(subset.metadata_file):
|
||||
logger.info(f"loading existing metadata: {subset.metadata_file}")
|
||||
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
metadata = dataset_metadata_utils.load_metadata(subset.metadata_file)
|
||||
if metadata is None:
|
||||
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
|
||||
images_metadata = metadata.get("images")
|
||||
|
||||
if len(metadata) < 1:
|
||||
if images_metadata is None or len(images_metadata) == 0:
|
||||
logger.warning(
|
||||
f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します"
|
||||
)
|
||||
continue
|
||||
|
||||
tags_list = []
|
||||
for image_key, img_md in metadata.items():
|
||||
# path情報を作る
|
||||
abs_path = None
|
||||
default_latent_cache_dir = metadata.get("latent_cache_dir")
|
||||
for image_key, img_md in images_metadata.items():
|
||||
latent_cache_dir = img_md.get("latent_cache_dir", default_latent_cache_dir)
|
||||
image_size = img_md.get("image_size")
|
||||
|
||||
# まず画像を優先して探す
|
||||
if os.path.exists(image_key):
|
||||
abs_path = image_key
|
||||
# check if image_key is valid
|
||||
is_valid = False
|
||||
if dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR in image_key:
|
||||
if latent_cache_dir is not None:
|
||||
latent_cache_path = latent_caching_strategy.get_latents_cache_path(image_key, image_size, latent_cache_dir)
|
||||
is_valid = os.path.exists(latent_cache_path)
|
||||
if not is_valid:
|
||||
# if latent_cache_dir is not specified or cache file does not exist, we cannot check the existence of image
|
||||
if not info_for_image_in_archive_is_shown:
|
||||
logger.info(
|
||||
f"image in archive is not checked for existence / アーカイブ内の画像は存在チェックができません。"
|
||||
)
|
||||
info_for_image_in_archive_is_shown = True
|
||||
is_valid = True
|
||||
else:
|
||||
# わりといい加減だがいい方法が思いつかん
|
||||
paths = glob_images(subset.image_dir, image_key)
|
||||
if len(paths) > 0:
|
||||
abs_path = paths[0]
|
||||
# check existence of image file or cache file
|
||||
is_valid = os.path.exists(image_key)
|
||||
if not is_valid:
|
||||
latent_cache_path = latent_caching_strategy.get_latents_cache_path(image_key, image_size, latent_cache_dir)
|
||||
is_valid = os.path.exists(latent_cache_path)
|
||||
|
||||
# なければnpzを探す
|
||||
if abs_path is None:
|
||||
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
|
||||
abs_path = os.path.splitext(image_key)[0] + ".npz"
|
||||
else:
|
||||
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
if os.path.exists(npz_path):
|
||||
abs_path = npz_path
|
||||
assert is_valid, f"no image / 画像がありません: {image_key}"
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
captions: Optional[list[str]] = img_md.get("caption")
|
||||
caption_weights: Optional[list[float]] = img_md.get("caption_weights")
|
||||
list_of_tags: Optional[list[str]] = img_md.get("tags")
|
||||
tags_weights: Optional[list[float]] = img_md.get("tags_weights")
|
||||
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
if caption is None:
|
||||
caption = tags # could be multiline
|
||||
tags = None
|
||||
|
||||
if subset.enable_wildcard:
|
||||
# tags must be single line
|
||||
if tags is not None:
|
||||
tags = tags.replace("\n", subset.caption_separator)
|
||||
|
||||
# add tags to each line of caption
|
||||
if caption is not None and tags is not None:
|
||||
caption = "\n".join(
|
||||
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
|
||||
)
|
||||
else:
|
||||
# use as is
|
||||
if tags is not None and len(tags) > 0:
|
||||
caption = caption + subset.caption_separator + tags
|
||||
tags_list.append(tags)
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
|
||||
image_info.image_size = img_md.get("train_resolution")
|
||||
|
||||
if not subset.color_aug and not subset.random_crop:
|
||||
# if cache exists, use them
|
||||
image_info.latents_cache_path, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
|
||||
image_info = ImageInfo(image_key, subset.num_repeats, False, image_key)
|
||||
image_info.set_fine_tuning_info(captions, caption_weights, list_of_tags, tags_weights, latent_cache_dir, image_size)
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
# count the image entries under metadata["images"], not the top-level keys of the metadata dict
self.num_train_images += len(images_metadata) * subset.num_repeats
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
|
||||
if list_of_tags is not None:
|
||||
self.set_tag_frequency(os.path.basename(subset.metadata_file), list_of_tags)
|
||||
# number of image entries in this subset, not the number of top-level metadata keys
subset.img_count = len(images_metadata)
|
||||
self.subsets.append(subset)
|
||||
|
||||
# check existence of all npz files
|
||||
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
|
||||
if use_npz_latents:
|
||||
flip_aug_in_subset = False
|
||||
npz_any = False
|
||||
npz_all = True
|
||||
|
||||
for image_info in self.image_data.values():
|
||||
subset = self.image_to_subset[image_info.image_key]
|
||||
|
||||
has_npz = image_info.latents_cache_path is not None
|
||||
npz_any = npz_any or has_npz
|
||||
|
||||
if subset.flip_aug:
|
||||
has_npz = has_npz and image_info.latents_npz_flipped is not None
|
||||
flip_aug_in_subset = True
|
||||
npz_all = npz_all and has_npz
|
||||
|
||||
if npz_any and not npz_all:
|
||||
break
|
||||
|
||||
if not npz_any:
|
||||
use_npz_latents = False
|
||||
logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
|
||||
elif not npz_all:
|
||||
use_npz_latents = False
|
||||
logger.warning(
|
||||
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
|
||||
)
|
||||
if flip_aug_in_subset:
|
||||
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
|
||||
# else:
|
||||
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
|
||||
|
||||
# check min/max bucket size
|
||||
sizes = set()
|
||||
resos = set()
|
||||
for image_info in self.image_data.values():
|
||||
if image_info.image_size is None:
|
||||
sizes = None # not calculated
|
||||
break
|
||||
sizes.add(image_info.image_size[0])
|
||||
sizes.add(image_info.image_size[1])
|
||||
resos.add(tuple(image_info.image_size))
|
||||
|
||||
if sizes is None:
|
||||
if use_npz_latents:
|
||||
use_npz_latents = False
|
||||
logger.warning(
|
||||
f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
|
||||
)
|
||||
|
||||
assert (
|
||||
resolution is not None
|
||||
), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
self.bucket_no_upscale = bucket_no_upscale
|
||||
else:
|
||||
if not enable_bucket:
|
||||
logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
|
||||
logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います")
|
||||
self.enable_bucket = True
|
||||
|
||||
assert (
|
||||
not bucket_no_upscale
|
||||
), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
|
||||
|
||||
# bucket情報を初期化しておく、make_bucketsで再作成しない
|
||||
self.bucket_manager = BucketManager(False, None, None, None, None)
|
||||
self.bucket_manager.set_predefined_resos(resos)
|
||||
|
||||
# npz情報をきれいにしておく
|
||||
if not use_npz_latents:
|
||||
for image_info in self.image_data.values():
|
||||
image_info.latents_cache_path = image_info.latents_npz_flipped = None
|
||||
|
||||
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
|
||||
base_name = os.path.splitext(image_key)[0]
|
||||
npz_file_norm = base_name + ".npz"
|
||||
|
||||
if os.path.exists(npz_file_norm):
|
||||
# image_key is full path
|
||||
npz_file_flip = base_name + "_flip.npz"
|
||||
if not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
return npz_file_norm, npz_file_flip
|
||||
|
||||
# if not full path, check image_dir. if image_dir is None, return None
|
||||
if subset.image_dir is None:
|
||||
return None, None
|
||||
|
||||
# image_key is relative path
|
||||
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
|
||||
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
|
||||
|
||||
if not os.path.exists(npz_file_norm):
|
||||
npz_file_norm = None
|
||||
npz_file_flip = None
|
||||
elif not os.path.exists(npz_file_flip):
|
||||
npz_file_flip = None
|
||||
|
||||
return npz_file_norm, npz_file_flip
|
||||
self.min_bucket_reso = None
|
||||
self.max_bucket_reso = None
|
||||
self.bucket_reso_steps = None # この情報は使われない
|
||||
self.bucket_no_upscale = False
|
||||
|
||||
|
||||
class ControlNetDataset(BaseDataset):
|
||||
|
||||
Reference in New Issue
Block a user