feat: update fine tuning dataset

This commit is contained in:
Kohya S
2024-12-09 20:52:18 +09:00
parent 70423ec61d
commit f2322a23e2
7 changed files with 202 additions and 248 deletions

View File

@@ -14,7 +14,7 @@ from PIL import Image
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM
from library import device_utils, train_util
from library import device_utils, train_util, dataset_metadata_utils
from library.utils import setup_logging
setup_logging()
@@ -91,7 +91,7 @@ def main(args):
# load metadata if needed
if args.metadata is not None:
metadata = tagger_utils.load_metadata(args.metadata)
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None

View File

@@ -13,6 +13,7 @@ from huggingface_hub import hf_hub_download
from PIL import Image
from tqdm import tqdm
from library import dataset_metadata_utils
from library.utils import setup_logging
setup_logging()
@@ -384,7 +385,7 @@ def main(args):
# load metadata if needed
if args.metadata is not None:
metadata = tagger_utils.load_metadata(args.metadata)
metadata = dataset_metadata_utils.load_metadata(args.metadata, create_new=True)
images_metadata = metadata["images"]
else:
images_metadata = metadata = None

View File

@@ -16,7 +16,7 @@ import logging
logger = logging.getLogger(__name__)
from library import train_util
from library import dataset_metadata_utils, train_util
class ArchiveImageLoader:
@@ -72,7 +72,9 @@ class ArchiveImageLoader:
break
file = self.files[self.image_index]
archive_and_image_path = f"{self.archive_paths[self.archive_index]}////{file}"
archive_and_image_path = (
f"{self.archive_paths[self.archive_index]}{dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR}{file}"
)
self.image_index += 1
def load_image(file, archive: Union[zipfile.ZipFile, tarfile.TarFile]):
@@ -133,29 +135,6 @@ class ImageLoader:
return [(image_path, image, size) for image_path, (image, size) in images]
def load_metadata(metadata_file: str):
if os.path.exists(metadata_file):
logger.info(f"loading metadata file: {metadata_file}")
with open(metadata_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
# version check
major, minor, patch = metadata.get("format_version", "0.0.0").split(".")
major, minor, patch = int(major), int(minor), int(patch)
if major > 1 or (major == 1 and minor > 0):
logger.warning(
f"metadata format version {major}.{minor}.{patch} is higher than supported version 1.0.0. Some features may not work."
)
if "images" not in metadata:
metadata["images"] = {}
else:
logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
metadata = {"format_version": "1.0.0", "images": {}}
return metadata
def add_archive_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--metadata",

View File

@@ -0,0 +1,58 @@
import os
import json
from typing import Any, Optional
from .utils import setup_logging
setup_logging()
import logging
import zlib

logger = logging.getLogger(__name__)

# Supported metadata schema version as (major, minor, patch).
METADATA_VERSION = [1, 0, 0]
VERSION_STRING = ".".join(str(v) for v in METADATA_VERSION)

# Separator between an archive path and the member path inside it,
# e.g. "/data/images.zip////subdir/img.png".
ARCHIVE_PATH_SEPARATOR = "////"


def load_metadata(metadata_file: str, create_new: bool = False) -> Optional[dict[str, Any]]:
    """Load a dataset metadata JSON file.

    Args:
        metadata_file: Path to the metadata JSON file.
        create_new: When True and the file does not exist, return a fresh,
            empty metadata dict instead of None.

    Returns:
        The metadata dict (an "images" key is guaranteed to exist), or None
        when the file is missing and create_new is False.
    """
    if not os.path.exists(metadata_file):
        if not create_new:
            return None
        logger.info(f"metadata file not found: {metadata_file}, creating new metadata")
        return {"format_version": VERSION_STRING, "images": {}}

    logger.info(f"loading metadata file: {metadata_file}")
    with open(metadata_file, "rt", encoding="utf-8") as f:
        metadata = json.load(f)

    # Version check: warn when the file was written by a newer schema than this
    # code supports. A malformed version string is tolerated (treated as 0.0.0)
    # instead of crashing the load with an unpacking/int() ValueError.
    try:
        major, minor, patch = (int(v) for v in metadata.get("format_version", "0.0.0").split("."))
    except ValueError:
        logger.warning(f"metadata has invalid format_version, assuming 0.0.0: {metadata.get('format_version')}")
        major, minor, patch = 0, 0, 0
    if major > METADATA_VERSION[0] or (major == METADATA_VERSION[0] and minor > METADATA_VERSION[1]):
        logger.warning(
            f"metadata format version {major}.{minor}.{patch} is higher than supported version {VERSION_STRING}. Some features may not work."
        )

    if "images" not in metadata:
        metadata["images"] = {}
    return metadata


def is_archive_path(archive_and_image_path: str) -> bool:
    """Return True when the path contains exactly one archive/member separator."""
    return archive_and_image_path.count(ARCHIVE_PATH_SEPARATOR) == 1


def get_inner_path(archive_and_image_path: str) -> str:
    """Return the member path portion (after the separator) of a combined archive path."""
    return archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[1]


def get_archive_digest(archive_and_image_path: str) -> str:
    """Calculate an 8-digit hex digest for the archive path to avoid collisions
    between different archives with the same file name.

    Uses CRC32 instead of the builtin hash(): str hashing is salted per process
    (PYTHONHASHSEED), so hash()-derived cache file names would change between
    runs and on-disk caches would never be found again.
    """
    archive_path = archive_and_image_path.split(ARCHIVE_PATH_SEPARATOR, 1)[0]
    return f"{zlib.crc32(archive_path.encode('utf-8')) & 0xFFFFFFFF:08x}"

View File

@@ -17,7 +17,7 @@ import logging
logger = logging.getLogger(__name__)
from library import utils
from library import dataset_metadata_utils, utils
def get_compatible_dtypes(dtype: Optional[Union[str, torch.dtype]]) -> List[torch.dtype]:
@@ -648,8 +648,23 @@ class LatentsCachingStrategy:
w, h = os.path.splitext(cache_path)[0].rsplit("_", 2)[-2].split("x")
return int(w), int(h)
def get_latents_cache_path(self, absolute_path: str, image_size: Tuple[int, int]) -> str:
return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def get_latents_cache_path_from_info(self, info: utils.ImageInfo) -> str:
    """Resolve the latents cache file path for an image described by an ImageInfo."""
    cache_dir = info.latents_cache_dir
    return self.get_latents_cache_path(info.absolute_path, info.image_size, cache_dir)
def get_latents_cache_path(
    self, absolute_path_or_archive_img_path: str, image_size: Tuple[int, int], cache_dir: Optional[str] = None
) -> str:
    """Build the on-disk latents cache file path for an image.

    When cache_dir is given, the cache file is placed there; images inside
    archives get an archive-digest prefix so same-named members of different
    archives do not collide. Without cache_dir, the cache file sits next to
    the image file itself. The size suffix encodes width x height.
    """
    if cache_dir is None:
        cache_file_base = absolute_path_or_archive_img_path
    elif dataset_metadata_utils.is_archive_path(absolute_path_or_archive_img_path):
        digest = dataset_metadata_utils.get_archive_digest(absolute_path_or_archive_img_path)
        member = dataset_metadata_utils.get_inner_path(absolute_path_or_archive_img_path)
        cache_file_base = os.path.join(cache_dir, f"{digest}_{member}")
    else:
        cache_file_base = os.path.join(cache_dir, os.path.basename(absolute_path_or_archive_img_path))
    stem = os.path.splitext(cache_file_base)[0]
    return stem + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.cache_suffix
def is_disk_cached_latents_expected(
self,

View File

@@ -12,17 +12,7 @@ import pathlib
import re
import shutil
import time
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Union
)
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import glob
import math
@@ -69,7 +59,7 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler,
AutoencoderKL,
)
from library import custom_train_functions, sd3_utils
from library import custom_train_functions, dataset_metadata_utils, sd3_utils
from library.original_unet import UNet2DConditionModel
from huggingface_hub import hf_hub_download
import numpy as np
@@ -1071,8 +1061,7 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if caching_strategy.cache_to_disk:
# info.latents_npz = os.path.splitext(info.absolute_path)[0] + file_suffix
info.latents_cache_path = caching_strategy.get_latents_cache_path(info.absolute_path, info.image_size)
info.latents_cache_path = caching_strategy.get_latents_cache_path_from_info(info)
# if the modulo of num_processes is not equal to process_index, skip caching
# this makes each process cache different latents
@@ -1389,6 +1378,7 @@ class BaseDataset(torch.utils.data.Dataset):
input_ids = None
caption = ""
caption_index = None
if image_info.text_encoder_outputs is not None:
# cached on memory
text_encoder_outputs = image_info.text_encoder_outputs
@@ -1398,11 +1388,13 @@ class BaseDataset(torch.utils.data.Dataset):
text_encoder_outputs = random.choices(text_encoder_outputs, weights=image_info.caption_weights)[0]
elif image_info.text_encoder_outputs_cache_path is not None:
# on disk
index = 0
if len(image_info.captions) > 1:
index = random.choices(range(len(image_info.captions), weights=image_info.caption_weights))[0]
# captions_weights may be None
caption_index = random.choices(range(len(image_info.captions), weights=image_info.caption_weights))[0]
else:
caption_index = 0
text_encoder_outputs = self.text_encoder_output_caching_strategy.load_from_disk(
image_info.text_encoder_outputs_cache_path, index
image_info.text_encoder_outputs_cache_path, caption_index
)
else:
tokenization_required = True
@@ -1412,8 +1404,10 @@ class BaseDataset(torch.utils.data.Dataset):
caption = ""
tags = None # None if no tags in dataset metadata or Dreambooth method is used
if image_info.captions is not None and len(image_info.captions) > 0:
# captions_weights may be None
caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0]
if caption_index is not None: # partially cached
caption = image_info.captions[caption_index]
else:
caption = random.choices(image_info.captions, weights=image_info.caption_weights)[0]
if image_info.list_of_tags is not None and len(image_info.list_of_tags) > 0:
tags = random.choices(image_info.list_of_tags, weights=image_info.tags_weights)[0]
@@ -1582,7 +1576,7 @@ class DreamBoothDataset(BaseDataset):
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
def read_caption(img_path, caption_extension, enable_wildcard):
def read_caption(img_path, caption_extension) -> Optional[list[str]]:
# captionの候補ファイル名を作る
base_name = os.path.splitext(img_path)[0]
base_name_face_det = base_name
@@ -1591,7 +1585,7 @@ class DreamBoothDataset(BaseDataset):
base_name_face_det = "_".join(tokens[:-4])
cap_paths = [base_name + caption_extension, base_name_face_det + caption_extension]
caption = None
captions = None
for cap_path in cap_paths:
if os.path.isfile(cap_path):
with open(cap_path, "rt", encoding="utf-8") as f:
@@ -1601,14 +1595,12 @@ class DreamBoothDataset(BaseDataset):
logger.error(f"illegal char in file (not UTF-8) / ファイルにUTF-8以外の文字があります: {cap_path}")
raise e
assert len(lines) > 0, f"caption file is empty / キャプションファイルが空です: {cap_path}"
if enable_wildcard:
caption = "\n".join([line.strip() for line in lines if line.strip() != ""]) # 空行を除く、改行で連結
else:
caption = lines[0].strip()
captions = [line.strip() for line in lines]
captions = [cap for cap in captions if cap != ""] # remove empty lines
break
return caption
return captions
def load_dreambooth_dir(subset: DreamBoothSubset):
def load_dreambooth_dir(subset: DreamBoothSubset) -> Tuple[list[str], list[list[str]], list[list[int]]]:
if not os.path.isdir(subset.image_dir):
logger.warning(f"not directory: {subset.image_dir}")
return [], [], []
@@ -1627,7 +1619,7 @@ class DreamBoothDataset(BaseDataset):
use_cached_info_for_subset = False
if use_cached_info_for_subset:
# json: {`img_path`:{"caption": "caption...", "resolution": [width, height]}, ...}
# json: {`img_path`:{"caption": ["caption...", ...] , "resolution": [width, height]}, ...}
with open(info_cache_file, "r", encoding="utf-8") as f:
metas = json.load(f)
img_paths = list(metas.keys())
@@ -1676,28 +1668,31 @@ class DreamBoothDataset(BaseDataset):
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
if use_cached_info_for_subset:
captions = [meta["caption"] for meta in metas.values()]
missing_captions = [img_path for img_path, caption in zip(img_paths, captions) if caption is None or caption == ""]
list_of_captions = [meta["caption"] for meta in metas.values()] # list of captions for each image
missing_captions = [
img_path for img_path, caption in zip(img_paths, list_of_captions) if caption is None or caption == ""
]
else:
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
captions = []
list_of_captions: list[list[str]] = []
missing_captions = []
for img_path in tqdm(img_paths, desc="read caption"):
cap_for_img = read_caption(img_path, subset.caption_extension, subset.enable_wildcard)
if cap_for_img is None and subset.class_tokens is None:
caps_for_img = read_caption(img_path, subset.caption_extension)
if caps_for_img is None and subset.class_tokens is None:
logger.warning(
f"neither caption file nor class tokens are found. use empty caption for {img_path} / キャプションファイルもclass tokenも見つかりませんでした。空のキャプションを使用します: {img_path}"
)
captions.append("")
list_of_captions.append([""])
missing_captions.append(img_path)
else:
if cap_for_img is None:
captions.append(subset.class_tokens)
if caps_for_img is None:
list_of_captions.append([subset.class_tokens])
missing_captions.append(img_path)
else:
captions.append(cap_for_img)
list_of_captions.append(caps_for_img)
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
for captions in list_of_captions:
self.set_tag_frequency(os.path.basename(subset.image_dir), captions) # タグ頻度を記録
if missing_captions:
number_of_missing_captions = len(missing_captions)
@@ -1717,14 +1712,14 @@ class DreamBoothDataset(BaseDataset):
logger.info(f"cache image info for / 画像情報をキャッシュします : {info_cache_file}")
sizes = [self.get_image_size(img_path) for img_path in tqdm(img_paths, desc="get image size")]
matas = {}
for img_path, caption, size in zip(img_paths, captions, sizes):
matas[img_path] = {"caption": caption, "resolution": list(size)}
for img_path, captions, size in zip(img_paths, list_of_captions, sizes):
matas[img_path] = {"caption": captions, "resolution": list(size)}
with open(info_cache_file, "w", encoding="utf-8") as f:
json.dump(matas, f, ensure_ascii=False, indent=2)
logger.info(f"cache image info done for / 画像情報を出力しました : {info_cache_file}")
# if sizes are not set, image size will be read in make_buckets
return img_paths, captions, sizes
return img_paths, list_of_captions, sizes
logger.info("prepare images.")
num_train_images = 0
@@ -1743,7 +1738,7 @@ class DreamBoothDataset(BaseDataset):
)
continue
img_paths, captions, sizes = load_dreambooth_dir(subset)
img_paths, list_of_captions, sizes = load_dreambooth_dir(subset)
if len(img_paths) < 1:
logger.warning(
f"ignore subset with image_dir='{subset.image_dir}': no images found / 画像が見つからないためサブセットを無視します"
@@ -1755,9 +1750,10 @@ class DreamBoothDataset(BaseDataset):
else:
num_train_images += subset.num_repeats * len(img_paths)
for img_path, caption, size in zip(img_paths, captions, sizes):
captions = caption.split("\n") # empty line is allowed
info = ImageInfo(img_path, subset.num_repeats, captions, subset.is_reg, img_path)
for img_path, captions, size in zip(img_paths, list_of_captions, sizes):
# NOTE: captions of DreamBoothDataset is treated as tags. Because shuffle, drop, etc. are applied to them.
info = ImageInfo(img_path, subset.num_repeats, subset.is_reg, img_path)
info.list_of_tags = captions
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -1817,6 +1813,9 @@ class FineTuningDataset(BaseDataset):
self.num_train_images = 0
self.num_reg_images = 0
latent_caching_strategy = LatentsCachingStrategy.get_strategy()
info_for_image_in_archive_is_shown = False
for subset in subsets:
if subset.num_repeats < 1:
logger.warning(
@@ -1831,195 +1830,76 @@ class FineTuningDataset(BaseDataset):
continue
# メタデータを読み込む
if os.path.exists(subset.metadata_file):
logger.info(f"loading existing metadata: {subset.metadata_file}")
with open(subset.metadata_file, "rt", encoding="utf-8") as f:
metadata = json.load(f)
else:
metadata = dataset_metadata_utils.load_metadata(subset.metadata_file)
if metadata is None:
raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")
images_metadata = metadata.get("images")
if len(metadata) < 1:
if images_metadata is None or len(images_metadata) == 0:
logger.warning(
f"ignore subset with '{subset.metadata_file}': no image entries found / 画像に関するデータが見つからないためサブセットを無視します"
)
continue
tags_list = []
for image_key, img_md in metadata.items():
# path情報を作る
abs_path = None
default_latent_cache_dir = metadata.get("latent_cache_dir")
for image_key, img_md in images_metadata.items():
latent_cache_dir = img_md.get("latent_cache_dir", default_latent_cache_dir)
image_size = img_md.get("image_size")
# まず画像を優先して探す
if os.path.exists(image_key):
abs_path = image_key
# check if image_key is valid
is_valid = False
if dataset_metadata_utils.ARCHIVE_PATH_SEPARATOR in image_key:
if latent_cache_dir is not None:
latent_cache_path = latent_caching_strategy.get_latents_cache_path(image_key, image_size, latent_cache_dir)
is_valid = os.path.exists(latent_cache_path)
if not is_valid:
# if latent_cache_dir is not specified or cache file does not exist, we cannot check the existence of image
if not info_for_image_in_archive_is_shown:
logger.info(
f"image in archive is not checked for existence / アーカイブ内の画像は存在チェックができません。"
)
info_for_image_in_archive_is_shown = True
is_valid = True
else:
# わりといい加減だがいい方法が思いつかん
paths = glob_images(subset.image_dir, image_key)
if len(paths) > 0:
abs_path = paths[0]
# check existence of image file or cache file
is_valid = os.path.exists(image_key)
if not is_valid:
latent_cache_path = latent_caching_strategy.get_latents_cache_path(image_key, image_size, latent_cache_dir)
is_valid = os.path.exists(latent_cache_path)
# なければnpzを探す
if abs_path is None:
if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
abs_path = os.path.splitext(image_key)[0] + ".npz"
else:
npz_path = os.path.join(subset.image_dir, image_key + ".npz")
if os.path.exists(npz_path):
abs_path = npz_path
assert is_valid, f"no image / 画像がありません: {image_key}"
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
captions: Optional[list[str]] = img_md.get("caption")
caption_weights: Optional[list[float]] = img_md.get("caption_weights")
list_of_tags: Optional[list[str]] = img_md.get("tags")
tags_weights: Optional[list[float]] = img_md.get("tags_weights")
caption = img_md.get("caption")
tags = img_md.get("tags")
if caption is None:
caption = tags # could be multiline
tags = None
if subset.enable_wildcard:
# tags must be single line
if tags is not None:
tags = tags.replace("\n", subset.caption_separator)
# add tags to each line of caption
if caption is not None and tags is not None:
caption = "\n".join(
[f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
)
else:
# use as is
if tags is not None and len(tags) > 0:
caption = caption + subset.caption_separator + tags
tags_list.append(tags)
if caption is None:
caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info.image_size = img_md.get("train_resolution")
if not subset.color_aug and not subset.random_crop:
# if cache exists, use them
image_info.latents_cache_path, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
image_info = ImageInfo(image_key, subset.num_repeats, False, image_key)
image_info.set_fine_tuning_info(captions, caption_weights, list_of_tags, tags_weights, latent_cache_dir, image_size)
self.register_image(image_info, subset)
self.num_train_images += len(metadata) * subset.num_repeats
# TODO do not record tag freq when no tag
self.set_tag_frequency(os.path.basename(subset.metadata_file), tags_list)
if list_of_tags is not None:
self.set_tag_frequency(os.path.basename(subset.metadata_file), list_of_tags)
subset.img_count = len(metadata)
self.subsets.append(subset)
# check existence of all npz files
use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
if use_npz_latents:
flip_aug_in_subset = False
npz_any = False
npz_all = True
for image_info in self.image_data.values():
subset = self.image_to_subset[image_info.image_key]
has_npz = image_info.latents_cache_path is not None
npz_any = npz_any or has_npz
if subset.flip_aug:
has_npz = has_npz and image_info.latents_npz_flipped is not None
flip_aug_in_subset = True
npz_all = npz_all and has_npz
if npz_any and not npz_all:
break
if not npz_any:
use_npz_latents = False
logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
elif not npz_all:
use_npz_latents = False
logger.warning(
f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
)
if flip_aug_in_subset:
logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
# else:
# logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
# check min/max bucket size
sizes = set()
resos = set()
for image_info in self.image_data.values():
if image_info.image_size is None:
sizes = None # not calculated
break
sizes.add(image_info.image_size[0])
sizes.add(image_info.image_size[1])
resos.add(tuple(image_info.image_size))
if sizes is None:
if use_npz_latents:
use_npz_latents = False
logger.warning(
f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
)
assert (
resolution is not None
), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
self.enable_bucket = enable_bucket
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
self.enable_bucket = enable_bucket
if self.enable_bucket:
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
)
self.min_bucket_reso = min_bucket_reso
self.max_bucket_reso = max_bucket_reso
self.bucket_reso_steps = bucket_reso_steps
self.bucket_no_upscale = bucket_no_upscale
else:
if not enable_bucket:
logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います")
self.enable_bucket = True
assert (
not bucket_no_upscale
), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
# bucket情報を初期化しておく、make_bucketsで再作成しない
self.bucket_manager = BucketManager(False, None, None, None, None)
self.bucket_manager.set_predefined_resos(resos)
# npz情報をきれいにしておく
if not use_npz_latents:
for image_info in self.image_data.values():
image_info.latents_cache_path = image_info.latents_npz_flipped = None
def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
base_name = os.path.splitext(image_key)[0]
npz_file_norm = base_name + ".npz"
if os.path.exists(npz_file_norm):
# image_key is full path
npz_file_flip = base_name + "_flip.npz"
if not os.path.exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
# if not full path, check image_dir. if image_dir is None, return None
if subset.image_dir is None:
return None, None
# image_key is relative path
npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
if not os.path.exists(npz_file_norm):
npz_file_norm = None
npz_file_flip = None
elif not os.path.exists(npz_file_flip):
npz_file_flip = None
return npz_file_norm, npz_file_flip
self.min_bucket_reso = None
self.max_bucket_reso = None
self.bucket_reso_steps = None # この情報は使われない
self.bucket_no_upscale = False
class ControlNetDataset(BaseDataset):

View File

@@ -22,17 +22,16 @@ def fire_in_thread(f, *args, **kwargs):
class ImageInfo:
def __init__(
self, image_key: str, num_repeats: int, captions: Optional[Union[str, list[str]]], is_reg: bool, absolute_path: str
) -> None:
def __init__(self, image_key: str, num_repeats: int, is_reg: bool, absolute_path: str) -> None:
self.image_key: str = image_key
self.num_repeats: int = num_repeats
self.captions: Optional[list[str]] = None if captions is None else ([captions] if isinstance(captions, str) else captions)
self.captions: Optional[list[str]] = None
self.caption_weights: Optional[list[float]] = None # weights for each caption in sampling
self.list_of_tags: Optional[list[str]] = None
self.tags_weights: Optional[list[float]] = None
self.is_reg: bool = is_reg
self.absolute_path: str = absolute_path
self.latents_cache_dir: Optional[str] = None
self.image_size: Tuple[int, int] = None
self.resized_size: Tuple[int, int] = None
self.bucket_reso: Tuple[int, int] = None
@@ -55,6 +54,28 @@ class ImageInfo:
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
def __str__(self) -> str:
    # Human-readable summary for logging/debugging; deliberately omits tags,
    # weights, sizes, and cache fields to keep log lines short.
    return f"ImageInfo(image_key={self.image_key}, num_repeats={self.num_repeats}, captions={self.captions}, is_reg={self.is_reg}, absolute_path={self.absolute_path})"
def set_dreambooth_info(self, list_of_tags: list[str]) -> None:
    # DreamBooth captions are stored as tags because shuffle/drop transforms
    # are applied to them (see the NOTE in DreamBoothDataset's subset loop).
    self.list_of_tags = list_of_tags
def set_fine_tuning_info(
    self,
    captions: Optional[list[str]],
    caption_weights: Optional[list[float]],
    list_of_tags: Optional[list[str]],
    tags_weights: Optional[list[float]],
    latents_cache_dir: Optional[str],
    image_size: Tuple[int, int],
):
    """Populate caption/tag sampling data and latents-cache info for a fine-tuning image.

    Weights lists may be None (uniform sampling); when present they parallel
    their caption/tag lists.

    NOTE(review): the last two parameters were swapped to (latents_cache_dir,
    image_size) to match the call site in FineTuningDataset.__init__, which
    passes (..., latent_cache_dir, image_size) positionally. The previous
    order silently stored the cache dir as image_size and vice versa, which
    would later make get_latents_cache_path join a tuple into a path.
    """
    self.captions = captions
    self.caption_weights = caption_weights
    self.list_of_tags = list_of_tags
    self.tags_weights = tags_weights
    self.latents_cache_dir = latents_cache_dir
    self.image_size = image_size
# region Logging