Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-16 00:49:40 +00:00)
feat: update fine tuning dataset
@@ -14,7 +14,7 @@ from PIL import Image
 from tqdm import tqdm
 from transformers import AutoProcessor, AutoModelForCausalLM
 
-from library import device_utils, train_util
+from library import device_utils, train_util, dataset_metadata_utils
 from library.utils import setup_logging
 
 setup_logging()
@@ -91,7 +91,7 @@ def main(args):
 
     # load metadata if needed
     if args.metadata is not None:
-        metadata = tagger_utils.load_metadata(args.metadata)
+        metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
        images_metadata = metadata["images"]
    else:
        images_metadata = metadata = None
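Not part of the commit: a minimal standalone sketch of the behavior the callers above now rely on, based only on what this diff shows (an "images" dict, a "format_version" string, and the create_new flag); the file name and values are invented.

    import json
    import os

    def load_metadata_sketch(metadata_file: str, create_new: bool = False):
        # Existing files are parsed and get an "images" dict if it is missing;
        # a missing file yields a fresh structure only when create_new=True,
        # otherwise None is returned (mirrors the new library/dataset_metadata_utils.load_metadata).
        if os.path.exists(metadata_file):
            with open(metadata_file, "rt", encoding="utf-8") as f:
                metadata = json.load(f)
            metadata.setdefault("images", {})
            return metadata
        if not create_new:
            return None
        return {"format_version": "1.0.0", "images": {}}

    metadata = load_metadata_sketch("meta.json", create_new=True)
    images_metadata = metadata["images"]  # per-image entries are keyed by image path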
@@ -13,6 +13,7 @@ from huggingface_hub import hf_hub_download
 from PIL import Image
 from tqdm import tqdm
 
+from library import dataset_metadata_utils
 from library.utils import setup_logging
 
 setup_logging()
@@ -384,7 +385,7 @@ def main(args):
 
     # load metadata if needed
    if args.metadata is not None:
-        metadata = tagger_utils.load_metadata(args.metadata)
+        metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
        images_metadata = metadata["images"]
    else:
        images_metadata = metadata = None
@@ -16,7 +16,7 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-from library import train_util
+from library import dataset_metadata_utils, train_util
 
 
 class ArchiveImageLoader:
@@ -72,7 +72,9 @@ class ArchiveImageLoader:
                break
 
        file = self.files[self.image_index]
-        archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}"
+        archive_and_image_path = (
+            f"{self.archive_paths[self.archive_index]}{dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR}{file}"
+        )
        self.image_index += 1
 
        def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
@@ -133,29 +135,6 @@ class ImageLoader:
         return [(image_path, image, size) for image_path, (image, size) in images]
 
 
-def load_metadata(metadata_file: str):
-    if os.path.exists(metadata_file):
-        logger.info(f"loading metadata file: {metadata_file}")
-        with open(metadata_file, "rt", encoding="utf-8") as f:
-            metadata = json.load(f)
-
-        # version check
-        major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
-        major, minor, patch = int(major), int(minor), int(patch)
-        if major > 1 or (major == 1 and minor > 0):
-            logger.warning(
-                f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
-            )
-
-        if "images" not in metadata:
-            metadata["images"] = {}
-    else:
-        logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
-        metadata = {"format_version": "1.0.0", "images": {}}
-
-    return metadata
-
-
 def add_archive_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--metadata",
library/dataset_metadata_utils.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+import os
+import json
+from typing import Any, Optional
+
+
+from .utils import setup_logging
+
+setup_logging()
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+METADATA_VERSION = [1, 0, 0]
+VERSION_STRING = ".".join(str(v) for v in METADATA_VERSION)
+
+ARCHIVE_PATH_SEPARATOR = "////"
+
+
+def load_metadata(metadata_file: str, create_new: bool = False) -> Optional[dict[str, Any]]:
+    if os.path.exists(metadata_file):
+        logger.info(f"loading metadata file: {metadata_file}")
+        with open(metadata_file, "rt", encoding="utf-8") as f:
+            metadata = json.load(f)
+
+        # version check
+        major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
+        major, minor, patch = int(major), int(minor), int(patch)
+        if major > METADATA_VERSION[0] or (major == METADATA_VERSION[0] and minor > METADATA_VERSION[1]):
+            logger.warning(
+                f"metadata format version {major}.{minor}.{patch} is higher than supported version {VERSION_STRING}. Some features may not work."
+            )
+
+        if "images" not in metadata:
+            metadata["images"] = {}
+    else:
+        if not create_new:
+            return None
+        logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
+        metadata = {"format_version": VERSION_STRING, "images": {}}
+
+    return metadata
+
+
+def is_archive_path(archive_and_image_path: str) -> bool:
+    return archive_and_image_path.count(ARCHIVE_PATH_SEPARATOR) == 1
+
+
+def get_inner_path(archive_and_image_path: str) -> str:
+    return archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[1]
+
+
+def get_archive_digest(archive_and_image_path: str) -> str:
+    """
+    calculate a 8-digits hex digest for the archive path to avoid collisions for different archives with the same name.
+    """
+    archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
+    return f"{hash(archive_path) & 0xFFFFFFFF:08x}"
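Not part of the commit: a short sketch of how the new archive-path helpers fit together; the separator value and digest formula are taken from the file above, the paths are invented.

    ARCHIVE_PATH_SEPARATOR = "////"  # value defined in library/dataset_metadata_utils.py

    archive_path = "/data/images.zip"
    inner_file = "folder/cat_001.png"

    # an image inside an archive is addressed by "<archive path>////<inner path>"
    archive_and_image_path = f"{archive_path}{ARCHIVE_PATH_SEPARATOR}{inner_file}"

    # the helpers split that key back apart
    is_archive = archive_and_image_path.count(ARCHIVE_PATH_SEPARATOR) == 1   # is_archive_path
    inner_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[1]  # get_inner_path
    outer_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
    digest = f"{hash(outer_path) & 0xFFFFFFFF:08x}"                          # get_archive_digest
    print(is_archive, inner_path, digest)

One observation on the design: Python salts str hashes per process unless PYTHONHASHSEED is fixed, so the digest printed here is only stable within a single run.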
@@ -17,7 +17,7 @@ import logging
 
 logger = logging.getLogger(__name__)
 
-from library import utils
+from library import dataset_metadata_utils, utils
 
 
 def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
@@ -648,8 +648,23 @@ class LatentsCachingStrategy:
         w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
         return int(w), int(h)
 
-    def get_latents_cache_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
-        return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
+    def get_latents_cache_path_from_info(self, info: utils.ImageInfo) -> str:
+        return self.get_latents_cache_path(info.absolute_path, info.image_size, info.latents_cache_dir)
+
+    def get_latents_cache_path(
+        self, absolute_path_or_archive_img_path: str, image_size: Tuple[int, int], cache_dir: Optional[str] = None
+    ) -> str:
+        if cache_dir is not None:
+            if dataset_metadata_utils.is_archive_path(absolute_path_or_archive_img_path):
+                inner_path = dataset_metadata_utils.get_inner_path(absolute_path_or_archive_img_path)
+                archive_digest = dataset_metadata_utils.get_archive_digest(absolute_path_or_archive_img_path)
+                cache_file_base = os.path.join(cache_dir, f"{archive_digest}_{inner_path}")
+            else:
+                cache_file_base = os.path.join(cache_dir, os.path.basename(absolute_path_or_archive_img_path))
+        else:
+            cache_file_base = absolute_path_or_archive_img_path
+
+        return os.path.splitext(cache_file_base)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
 
     def is_disk_cached_latents_expected(
         self,
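Not part of the commit: a standalone sketch of the cache-path branching that the reworked get_latents_cache_path implements; the cache_suffix here is a hypothetical placeholder, since the real suffix comes from the concrete strategy subclass.

    import os

    ARCHIVE_PATH_SEPARATOR = "////"
    cache_suffix = ".npz"  # hypothetical; the real value is provided by the strategy subclass

    def latents_cache_path_sketch(path_or_key: str, image_size, cache_dir=None) -> str:
        if cache_dir is not None:
            if path_or_key.count(ARCHIVE_PATH_SEPARATOR) == 1:
                # archive member: "<digest>_<inner path>" inside cache_dir
                outer, inner = path_or_key.split(ARCHIVE_PATH_SEPARATOR, 1)
                digest = f"{hash(outer) & 0xFFFFFFFF:08x}"
                base = os.path.join(cache_dir, f"{digest}_{inner}")
            else:
                # plain file: keep only the base name inside cache_dir
                base = os.path.join(cache_dir, os.path.basename(path_or_key))
        else:
            # no cache_dir: cache next to the image, as before
            base = path_or_key
        return os.path.splitext(base)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + cache_suffix

    print(latents_cache_path_sketch("/data/imgs.zip////cats/001.png", (1024, 768), cache_dir="/cache"))
    print(latents_cache_path_sketch("/data/train/dog_001.png", (768, 768)))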
@@ -12,17 +12,7 @@ import pathlib
 import re
 import shutil
 import time
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Sequence,
-    Tuple,
-    Union
-)
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
 from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
 import glob
 import math
@@ -69,7 +59,7 @@ from diffusers import (
     KDPM2AncestralDiscreteScheduler,
     AutoencoderKL,
 )
-from library import custom_train_functions, sd3_utils
+from library import custom_train_functions, dataset_metadata_utils, sd3_utils
 from library.original_unet import UNet2DConditionModel
 from huggingface_hub import hf_hub_download
 import numpy as np
@@ -1071,8 +1061,7 @@ class BaseDataset(torch.utils.data.Dataset):
 
            # check disk cache exists and size of latents
            if caching_strategy.cache_to_disk:
-                # info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
-                info.latents_cache_path = caching_strategy.get_latents_cache_path(info.absolute_path, info.image_size)
+                info.latents_cache_path = caching_strategy.get_latents_cache_path_from_info(info)
 
                # if the modulo of num_processes is not equal to process_index, skip caching
                # this makes each process cache different latents
@@ -1389,6 +1378,7 @@ class BaseDataset(torch.utils.data.Dataset):
            input_ids = None
            caption = ""
 
+            caption_index = None
            if image_info.text_encoder_outputs is not None:
                # cached on memory
                text_encoder_outputs = image_info.text_encoder_outputs
@@ -1398,11 +1388,13 @@ class BaseDataset(torch.utils.data.Dataset):
                    text_encoder_outputs = random.choices(text_encoder_outputs, weights=image_info.caption_weights)[0]
            elif image_info.text_encoder_outputs_cache_path is not None:
                # on disk
-                index = 0
                if len(image_info.captions) > 1:
-                    index = random.choices(range(len(image_info.captions), weights=image_info.caption_weights))[0]
+                    # captions_weights may be None
+                    caption_index = random.choices(range(len(image_info.captions), weights=image_info.caption_weights))[0]
+                else:
+                    caption_index = 0
                text_encoder_outputs = self.text_encoder_output_caching_strategy.load_from_disk(
-                    image_info.text_encoder_outputs_cache_path, index
+                    image_info.text_encoder_outputs_cache_path, caption_index
                )
            else:
                tokenization_required = True
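Not part of the commit: a tiny sketch of the index selection added above, with invented captions; the weights may be None, in which case random.choices samples uniformly.

    import random

    captions = ["a cat", "a cat sitting on a sofa", "a tabby cat, indoors"]
    caption_weights = None  # may be None; random.choices then samples uniformly

    if len(captions) > 1:
        caption_index = random.choices(range(len(captions)), weights=caption_weights)[0]
    else:
        caption_index = 0

    # the same index is used both to load the cached text encoder output from disk
    # and, later, to pick the matching caption string for the batch
    print(caption_index, captions[caption_index])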
@@ -1412,8 +1404,10 @@ class BaseDataset(torch.utils.data.Dataset):
            caption = ""
            tags = None  # None if no tags in dataset metadata or Dreambooth method is used
            if image_info.captions is not None and len(image_info.captions) > 0:
-                # captions_weights may be None
-                caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0]
+                if caption_index is not None:  # partially cached
+                    caption = image_info.captions[caption_index]
+                else:
+                    caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0]
            if image_info.list_of_tags is not None and len(image_info.list_of_tags) > 0:
                tags = random.choices(image_info.list_of_tags, weights=image_info.tags_weights)[0]
 
@@ -1582,7 +1576,7 @@ class DreamBoothDataset(BaseDataset):
        self.bucket_reso_steps = None  # この情報は使われない
        self.bucket_no_upscale = False
 
-        def read_caption(img_path, caption_extension, enable_wildcard):
+        def read_caption(img_path, caption_extension) -> Optional[list[str]]:
            # captionの候補ファイル名を作る
            base_name = os.path.splitext(img_path)[0]
            base_name_face_det = base_name
@@ -1591,7 +1585,7 @@ class DreamBoothDataset(BaseDataset):
                base_name_face_det = "_".join(tokens[:-4])
            cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
 
-            caption = None
+            captions = None
            for cap_path in cap_paths:
                if os.path.isfile(cap_path):
                    with open(cap_path, "rt", encoding="utf-8") as f:
@@ -1601,14 +1595,12 @@ class DreamBoothDataset(BaseDataset):
                            logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
                            raise e
                        assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
-                        if enable_wildcard:
-                            caption = "\n".join([line.strip() for line in lines if line.strip() != ""])  # 空行を除く、改行で連結
-                        else:
-                            caption = lines[0].strip()
+                        captions = [line.strip() for line in lines]
+                        captions = [cap for cap in captions if cap != ""]  # remove empty lines
                    break
-            return caption
+            return captions
 
-        def load_dreambooth_dir(subset: DreamBoothSubset):
+        def load_dreambooth_dir(subset: DreamBoothSubset) -> Tuple[list[str], list[list[str]], list[list[int]]]:
            if not os.path.isdir(subset.image_dir):
                logger.warning(f"not directory: {subset.image_dir}")
                return [], [], []
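Not part of the commit: a self-contained sketch of the reworked caption reading, where every non-empty line of the caption file becomes one caption candidate; the face-detection fallback file name handling of the original is omitted here.

    import os
    from typing import Optional

    def read_caption_sketch(img_path: str, caption_extension: str = ".txt") -> Optional[list[str]]:
        cap_path = os.path.splitext(img_path)[0] + caption_extension
        if not os.path.isfile(cap_path):
            return None  # no caption file for this image
        with open(cap_path, "rt", encoding="utf-8") as f:
            lines = f.readlines()
        captions = [line.strip() for line in lines]
        return [cap for cap in captions if cap != ""]  # remove empty lines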
@@ -1627,7 +1619,7 @@ class DreamBoothDataset(BaseDataset):
                    use_cached_info_for_subset = False
 
            if use_cached_info_for_subset:
-                # json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
+                # json: {`img_path`:{"caption": ["caption...", ...] , "resolution": [width, height]}, ...}
                with open(info_cache_file, "r", encoding="utf-8") as f:
                    metas = json.load(f)
                img_paths = list(metas.keys())
@@ -1676,28 +1668,31 @@ class DreamBoothDataset(BaseDataset):
            logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
            if use_cached_info_for_subset:
-                captions = [meta["caption"] for meta in metas.values()]
-                missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
+                list_of_captions = [meta["caption"] for meta in metas.values()]  # list of captions for each image
+                missing_captions = [
+                    img_path for img_path, caption in zip(img_paths, list_of_captions) if caption is None or caption == ""
+                ]
            else:
                # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
-                captions = []
+                list_of_captions: list[list[str]] = []
                missing_captions = []
                for img_path in tqdm(img_paths, desc="read caption"):
-                    cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
-                    if cap_for_img is None and subset.class_tokens is None:
+                    caps_for_img = read_caption(img_path, subset.caption_extension)
+                    if caps_for_img is None and subset.class_tokens is None:
                        logger.warning(
                            f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
                        )
-                        captions.append("")
+                        list_of_captions.append([""])
                        missing_captions.append(img_path)
                    else:
-                        if cap_for_img is None:
-                            captions.append(subset.class_tokens)
+                        if caps_for_img is None:
+                            list_of_captions.append([subset.class_tokens])
                            missing_captions.append(img_path)
                        else:
-                            captions.append(cap_for_img)
+                            list_of_captions.append(caps_for_img)
 
-            self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録
+            for captions in list_of_captions:
+                self.set_tag_frequency(os.path.basename(subset.image_dir), captions)  # タグ頻度を記録
 
            if missing_captions:
                number_of_missing_captions = len(missing_captions)
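Not part of the commit: a sketch of the per-image fallback logic above with invented inputs; images without a caption file fall back to the subset's class tokens (or an empty caption) and are recorded as missing.

    # hypothetical per-image caption lists, e.g. what read_caption returned for each file
    raw_captions = {"img_0001.png": ["a cat", "a tabby cat, indoors"], "img_0002.png": None}
    class_tokens = "sks cat"  # stands in for subset.class_tokens; may be None

    list_of_captions: list[list[str]] = []
    missing_captions: list[str] = []
    for img_path, caps_for_img in raw_captions.items():
        if caps_for_img is None:
            list_of_captions.append([class_tokens] if class_tokens is not None else [""])
            missing_captions.append(img_path)
        else:
            list_of_captions.append(caps_for_img)
    print(list_of_captions, missing_captions)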
@@ -1717,14 +1712,14 @@ class DreamBoothDataset(BaseDataset):
                logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
                sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
                matas = {}
-                for img_path, caption, size in zip(img_paths, captions, sizes):
-                    matas[img_path] = {"caption": caption, "resolution": list(size)}
+                for img_path, captions, size in zip(img_paths, list_of_captions, sizes):
+                    matas[img_path] = {"caption": captions, "resolution": list(size)}
                with open(info_cache_file, "w", encoding="utf-8") as f:
                    json.dump(matas, f, ensure_ascii=False, indent=2)
                logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")
 
            # if sizes are not set, image size will be read in make_buckets
-            return img_paths, captions, sizes
+            return img_paths, list_of_captions, sizes
 
        logger.info("prepare images.")
        num_train_images = 0
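Not part of the commit: an illustrative dump of the per-subset info cache after this change, where "caption" now holds a list of candidates instead of a single string; field names come from the diff, values are invented.

    import json

    metas = {
        "img_0001.png": {"caption": ["a cat", "a tabby cat, indoors"], "resolution": [1024, 768]},
        "img_0002.png": {"caption": ["sks cat"], "resolution": [768, 768]},
    }
    print(json.dumps(metas, ensure_ascii=False, indent=2))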
@@ -1743,7 +1738,7 @@ class DreamBoothDataset(BaseDataset):
                )
                continue
 
-            img_paths, captions, sizes = load_dreambooth_dir(subset)
+            img_paths, list_of_captions, sizes = load_dreambooth_dir(subset)
            if len(img_paths) < 1:
                logger.warning(
                    f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1755,9 +1750,10 @@ class DreamBoothDataset(BaseDataset):
            else:
                num_train_images += subset.num_repeats * len(img_paths)
 
-            for img_path, caption, size in zip(img_paths, captions, sizes):
-                captions = caption.split("\n")  # empty line is allowed
-                info = ImageInfo(img_path, subset.num_repeats, captions, subset.is_reg, img_path)
+            for img_path, captions, size in zip(img_paths, list_of_captions, sizes):
+                # NOTE: captions of DreamBoothDataset is treated as tags. Because shuffle, drop, etc. are applied to them.
+                info = ImageInfo(img_path, subset.num_repeats, subset.is_reg, img_path)
+                info.list_of_tags = captions
                if size is not None:
                    info.image_size = size
                if subset.is_reg:
@@ -1817,6 +1813,9 @@ class FineTuningDataset(BaseDataset):
        self.num_train_images = 0
        self.num_reg_images = 0
 
+        latent_caching_strategy = LatentsCachingStrategy.get_strategy()
+
+        info_for_image_in_archive_is_shown = False
        for subset in subsets:
            if subset.num_repeats < 1:
                logger.warning(
@@ -1831,195 +1830,76 @@ class FineTuningDataset(BaseDataset):
                continue
 
            # メタデータを読み込む
-            if os.path.exists(subset.metadata_file):
-                logger.info(f"loading existing metadata: {subset.metadata_file}")
-                with open(subset.metadata_file, "rt", encoding="utf-8") as f:
-                    metadata = json.load(f)
-            else:
+            metadata = dataset_metadata_utils.load_metadata(subset.metadata_file)
+            if metadata is None:
                raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
+            images_metadata = metadata.get("images")
 
-            if len(metadata) < 1:
+            if images_metadata is None or len(images_metadata) == 0:
                logger.warning(
                    f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します"
                )
                continue
 
-            tags_list = []
-            for image_key, img_md in metadata.items():
-                # path情報を作る
-                abs_path = None
+            default_latent_cache_dir = metadata.get("latent_cache_dir")
+            for image_key, img_md in images_metadata.items():
+                latent_cache_dir = img_md.get("latent_cache_dir", default_latent_cache_dir)
+                image_size = img_md.get("image_size")
 
-                # まず画像を優先して探す
-                if os.path.exists(image_key):
-                    abs_path = image_key
+                # check if image_key is valid
+                is_valid = False
+                if dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR in image_key:
+                    if latent_cache_dir is not None:
+                        latent_cache_path = latent_caching_strategy.get_latents_cache_path(image_key, image_size, latent_cache_dir)
+                        is_valid = os.path.exists(latent_cache_path)
+                    if not is_valid:
+                        # if latent_cache_dir is not specified or cache file does not exist, we cannot check the existence of image
+                        if not info_for_image_in_archive_is_shown:
+                            logger.info(
+                                f"image in archive is not checked for existence / アーカイブ内の画像は存在チェックができません。"
+                            )
+                            info_for_image_in_archive_is_shown = True
+                        is_valid = True
                else:
-                    # わりといい加減だがいい方法が思いつかん
-                    paths = glob_images(subset.image_dir, image_key)
-                    if len(paths) > 0:
-                        abs_path = paths[0]
+                    # check existence of image file or cache file
+                    is_valid = os.path.exists(image_key)
+                    if not is_valid:
+                        latent_cache_path = latent_caching_strategy.get_latents_cache_path(image_key, image_size, latent_cache_dir)
+                        is_valid = os.path.exists(latent_cache_path)
 
-                # なければnpzを探す
-                if abs_path is None:
-                    if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
-                        abs_path = os.path.splitext(image_key)[0] + ".npz"
-                    else:
-                        npz_path = os.path.join(subset.image_dir, image_key + ".npz")
-                        if os.path.exists(npz_path):
-                            abs_path = npz_path
+                assert is_valid, f"no image / 画像がありません: {image_key}"
 
-                assert abs_path is not None, f"no image / 画像がありません: {image_key}"
+                captions: Optional[list[str]] = img_md.get("caption")
+                caption_weights: Optional[list[float]] = img_md.get("caption_weights")
+                list_of_tags: Optional[list[str]] = img_md.get("tags")
+                tags_weights: Optional[list[float]] = img_md.get("tags_weights")
 
-                caption = img_md.get("caption")
-                tags = img_md.get("tags")
-                if caption is None:
-                    caption = tags  # could be multiline
-                    tags = None
-
-                if subset.enable_wildcard:
-                    # tags must be single line
-                    if tags is not None:
-                        tags = tags.replace("\n", subset.caption_separator)
-
-                    # add tags to each line of caption
-                    if caption is not None and tags is not None:
-                        caption = "\n".join(
-                            [f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
-                        )
-                else:
-                    # use as is
-                    if tags is not None and len(tags) > 0:
-                        caption = caption + subset.caption_separator + tags
-                        tags_list.append(tags)
-
-                if caption is None:
-                    caption = ""
-
-                image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
-                image_info.image_size = img_md.get("train_resolution")
-
-                if not subset.color_aug and not subset.random_crop:
-                    # if cache exists, use them
-                    image_info.latents_cache_path, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
+                image_info = ImageInfo(image_key, subset.num_repeats, False, image_key)
+                image_info.set_fine_tuning_info(captions, caption_weights, list_of_tags, tags_weights, latent_cache_dir, image_size)
 
                self.register_image(image_info, subset)
 
            self.num_train_images += len(metadata) * subset.num_repeats
 
            # TODO do not record tag freq when no tag
-            self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
+            if list_of_tags is not None:
+                self.set_tag_frequency(os.path.basename(subset.metadata_file), list_of_tags)
            subset.img_count = len(metadata)
            self.subsets.append(subset)
 
-        # check existence of all npz files
-        use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
-        if use_npz_latents:
-            flip_aug_in_subset = False
-            npz_any = False
-            npz_all = True
-
-            for image_info in self.image_data.values():
-                subset = self.image_to_subset[image_info.image_key]
-
-                has_npz = image_info.latents_cache_path is not None
-                npz_any = npz_any or has_npz
-
-                if subset.flip_aug:
-                    has_npz = has_npz and image_info.latents_npz_flipped is not None
-                    flip_aug_in_subset = True
-                npz_all = npz_all and has_npz
-
-                if npz_any and not npz_all:
-                    break
-
-            if not npz_any:
-                use_npz_latents = False
-                logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
-            elif not npz_all:
-                use_npz_latents = False
-                logger.warning(
-                    f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
-                )
-                if flip_aug_in_subset:
-                    logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
-        # else:
-        #     logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
-
        # check min/max bucket size
        sizes = set()
        resos = set()
        for image_info in self.image_data.values():
            if image_info.image_size is None:
                sizes = None  # not calculated
                break
            sizes.add(image_info.image_size[0])
            sizes.add(image_info.image_size[1])
            resos.add(tuple(image_info.image_size))
 
        if sizes is None:
-            if use_npz_latents:
-                use_npz_latents = False
-                logger.warning(
-                    f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
-                )
-
            assert (
                resolution is not None
            ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
 
-            self.enable_bucket = enable_bucket
-            if self.enable_bucket:
-                min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
-                    resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
-                )
-                self.min_bucket_reso = min_bucket_reso
-                self.max_bucket_reso = max_bucket_reso
-                self.bucket_reso_steps = bucket_reso_steps
-                self.bucket_no_upscale = bucket_no_upscale
+            self.enable_bucket = enable_bucket
+            if self.enable_bucket:
+                min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
+                    resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
+                )
+                self.min_bucket_reso = min_bucket_reso
+                self.max_bucket_reso = max_bucket_reso
+                self.bucket_reso_steps = bucket_reso_steps
+                self.bucket_no_upscale = bucket_no_upscale
        else:
            if not enable_bucket:
-                logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
+                logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います")
                self.enable_bucket = True
 
            assert (
                not bucket_no_upscale
            ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
 
            # bucket情報を初期化しておく、make_bucketsで再作成しない
            self.bucket_manager = BucketManager(False, None, None, None, None)
            self.bucket_manager.set_predefined_resos(resos)
 
-        # npz情報をきれいにしておく
-        if not use_npz_latents:
-            for image_info in self.image_data.values():
-                image_info.latents_cache_path = image_info.latents_npz_flipped = None
-
-    def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
-        base_name = os.path.splitext(image_key)[0]
-        npz_file_norm = base_name + ".npz"
-
-        if os.path.exists(npz_file_norm):
-            # image_key is full path
-            npz_file_flip = base_name + "_flip.npz"
-            if not os.path.exists(npz_file_flip):
-                npz_file_flip = None
-            return npz_file_norm, npz_file_flip
-
-        # if not full path, check image_dir. if image_dir is None, return None
-        if subset.image_dir is None:
-            return None, None
-
-        # image_key is relative path
-        npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
-        npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
-
-        if not os.path.exists(npz_file_norm):
-            npz_file_norm = None
-            npz_file_flip = None
-        elif not os.path.exists(npz_file_flip):
-            npz_file_flip = None
-
-        return npz_file_norm, npz_file_flip
+            self.min_bucket_reso = None
+            self.max_bucket_reso = None
+            self.bucket_reso_steps = None  # この情報は使われない
+            self.bucket_no_upscale = False
 
 
 class ControlNetDataset(BaseDataset):
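Not part of the commit: an illustrative metadata file as the reworked FineTuningDataset reads it; the top-level "latent_cache_dir" is a default that per-image entries may override, archive members are keyed with the "////" separator, and the field names ("caption", "caption_weights", "tags", "tags_weights", "image_size") come from the code above, while the concrete values are invented.

    metadata = {
        "format_version": "1.0.0",
        "latent_cache_dir": "/cache/latents",
        "images": {
            "/data/images.zip////cats/001.png": {
                "caption": ["a tabby cat", "a cat on a sofa"],
                "caption_weights": [0.7, 0.3],
                "tags": ["cat, indoors, sitting"],
                "tags_weights": None,
                "image_size": [1024, 768],
            },
            "/data/train/dog_001.png": {
                "caption": ["a corgi on grass"],
            },
        },
    }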
@@ -22,17 +22,16 @@ def fire_in_thread(f, *args, **kwargs):
 
 
 class ImageInfo:
-    def __init__(
-        self, image_key: str, num_repeats: int, captions: Optional[Union[str, list[str]]], is_reg: bool, absolute_path: str
-    ) -> None:
+    def __init__(self, image_key: str, num_repeats: int, is_reg: bool, absolute_path: str) -> None:
        self.image_key: str = image_key
        self.num_repeats: int = num_repeats
-        self.captions: Optional[list[str]] = None if captions is None else ([captions] if isinstance(captions, str) else captions)
+        self.captions: Optional[list[str]] = None
+        self.caption_weights: Optional[list[float]] = None  # weights for each caption in sampling
+        self.list_of_tags: Optional[list[str]] = None
+        self.tags_weights: Optional[list[float]] = None
        self.is_reg: bool = is_reg
        self.absolute_path: str = absolute_path
+        self.latents_cache_dir: Optional[str] = None
        self.image_size: Tuple[int, int] = None
        self.resized_size: Tuple[int, int] = None
        self.bucket_reso: Tuple[int, int] = None
@@ -55,6 +54,28 @@ class ImageInfo:
 
        self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
 
+    def __str__(self) -> str:
+        return f"ImageInfo(image_key={self.image_key}, num_repeats={self.num_repeats}, captions={self.captions}, is_reg={self.is_reg}, absolute_path={self.absolute_path})"
+
+    def set_dreambooth_info(self, list_of_tags: list[str]) -> None:
+        self.list_of_tags = list_of_tags
+
+    def set_fine_tuning_info(
+        self,
+        captions: Optional[list[str]],
+        caption_weights: Optional[list[float]],
+        list_of_tags: Optional[list[str]],
+        tags_weights: Optional[list[float]],
+        image_size: Tuple[int, int],
+        latents_cache_dir: Optional[str],
+    ):
+        self.captions = captions
+        self.caption_weights = caption_weights
+        self.list_of_tags = list_of_tags
+        self.tags_weights = tags_weights
+        self.image_size = image_size
+        self.latents_cache_dir = latents_cache_dir
+
 
 # region Logging
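Not part of the commit: a stand-in class mirroring the fields and the set_fine_tuning_info signature shown above, wired up with invented values; the real ImageInfo lives in the library module this hunk modifies.

    from typing import Optional, Tuple

    class ImageInfoSketch:
        def __init__(self, image_key: str, num_repeats: int, is_reg: bool, absolute_path: str) -> None:
            self.image_key = image_key
            self.num_repeats = num_repeats
            self.captions: Optional[list[str]] = None
            self.caption_weights: Optional[list[float]] = None
            self.list_of_tags: Optional[list[str]] = None
            self.tags_weights: Optional[list[float]] = None
            self.is_reg = is_reg
            self.absolute_path = absolute_path
            self.latents_cache_dir: Optional[str] = None
            self.image_size: Optional[Tuple[int, int]] = None

        def set_fine_tuning_info(self, captions, caption_weights, list_of_tags, tags_weights, image_size, latents_cache_dir):
            self.captions = captions
            self.caption_weights = caption_weights
            self.list_of_tags = list_of_tags
            self.tags_weights = tags_weights
            self.image_size = image_size
            self.latents_cache_dir = latents_cache_dir

    key = "/data/images.zip////cats/001.png"
    info = ImageInfoSketch(key, 1, False, key)
    info.set_fine_tuning_info(["a tabby cat"], None, ["cat, indoors"], None, (1024, 768), "/cache/latents")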