Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 15:00:23 +00:00)
fix: metadata dataset degradation and make it work (#2186)
* fix: support dataset with metadata
* feat: support another tagger model
* fix: improve handling of image size and caption/tag processing in FineTuningDataset
* fix: enhance metadata loading to support JSONL format in FineTuningDataset
* feat: enhance image loading and processing in ImageLoadingPrepDataset with batch support and output options
* fix: improve image path handling and memory management in dataset classes
* Update finetune/tag_images_by_wd14_tagger.py
* fix: add return type annotation for process_tag_replacement function and ensure tags are returned
* feat: add artist category threshold for tagging
* doc: add comment for clarification

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@@ -1131,7 +1131,8 @@ class BaseDataset(torch.utils.data.Dataset):
     def __eq__(self, other):
         return (
-            self.reso == other.reso
+            other is not None
+            and self.reso == other.reso
             and self.flip_aug == other.flip_aug
             and self.alpha_mask == other.alpha_mask
             and self.random_crop == other.random_crop
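The guard added here is the whole fix: comparisons like `current_condition != condition` start from `current_condition = None`, so `__eq__` can receive `other=None`, and the old body dereferenced `other.reso` unconditionally. A minimal self-contained sketch (the class name and fields are illustrative stand-ins for the condition object above):

    class Condition:
        """Illustrative stand-in for the batching condition object above."""

        def __init__(self, reso, flip_aug, alpha_mask, random_crop):
            self.reso = reso
            self.flip_aug = flip_aug
            self.alpha_mask = alpha_mask
            self.random_crop = random_crop

        def __eq__(self, other):
            # Without the `other is not None` guard, comparing against None
            # would raise AttributeError while evaluating `other.reso`.
            return (
                other is not None
                and self.reso == other.reso
                and self.flip_aug == other.flip_aug
                and self.alpha_mask == other.alpha_mask
                and self.random_crop == other.random_crop
            )

    cond = Condition((512, 512), False, False, False)
    assert (cond != None) is True   # the caching loop starts with current_condition = None
    assert (cond == None) is False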
@@ -1193,6 +1194,8 @@ class BaseDataset(torch.utils.data.Dataset):
             if len(batch) > 0 and current_condition != condition:
                 submit_batch(batch, current_condition)
                 batch = []
+            if condition != current_condition and HIGH_VRAM:  # even with high VRAM, if shape is changed
+                clean_memory_on_device(accelerator.device)

             if info.image is None:
                 # load image in parallel
@@ -1205,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset):
             if len(batch) >= caching_strategy.batch_size:
                 submit_batch(batch, current_condition)
                 batch = []
-                current_condition = None
+                # current_condition = None  # keep current_condition to avoid next `clean_memory_on_device` call

         if len(batch) > 0:
             submit_batch(batch, current_condition)
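Taken together, the two hunks above change when device memory is cleaned: the batch is still flushed whenever the grouping condition changes, but `current_condition` is no longer reset after a size-based flush, so an unchanged shape no longer looks like a change. A self-contained sketch of the pattern (`submit_batch` and `clean_memory_on_device` are stand-ins here, and the exact placement of the `HIGH_VRAM` check is inferred from the hunks):

    def submit_batch(batch, condition):  # stand-in for the real submit_batch
        print(f"submit {len(batch)} item(s) with condition {condition}")

    def clean_memory_on_device():  # stand-in for the library helper
        print("  -> clean device memory (shape changed)")

    HIGH_VRAM = True
    BATCH_SIZE = 2
    items = [("a", (512, 512)), ("b", (512, 512)), ("c", (512, 512)), ("d", (640, 384))]

    batch, current_condition = [], None
    for name, condition in items:
        if len(batch) > 0 and current_condition != condition:
            submit_batch(batch, current_condition)
            batch = []
        if condition != current_condition and HIGH_VRAM:
            # even with high VRAM, free the cache when the tensor shape changes
            clean_memory_on_device()
        batch.append(name)
        current_condition = condition
        if len(batch) >= BATCH_SIZE:
            submit_batch(batch, current_condition)
            batch = []
            # current_condition is deliberately kept: resetting it to None here
            # would make the next same-shaped item trigger a spurious cleanup
    if len(batch) > 0:
        submit_batch(batch, current_condition)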
@@ -1768,14 +1771,10 @@ class BaseDataset(torch.utils.data.Dataset):
             tensors = [converter(x) for x in tensors]
             if tensors[0].ndim == 1:
                 # input_ids or mask
-                result.append(
-                    torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors])
-                )
+                result.append(torch.stack([(torch.nn.functional.pad(x, (0, max_len - x.shape[0]))) for x in tensors]))
             else:
                 # text encoder outputs
-                result.append(
-                    torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors])
-                )
+                result.append(torch.stack([(torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[0]))) for x in tensors]))
         return result

     # set example
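The collate change above is purely cosmetic, but the two pad tuples deserve a note: `torch.nn.functional.pad` reads its pad argument from the last dimension backwards, so `(0, n)` appends n zeros to a 1-D input_ids/mask tensor, while `(0, 0, 0, n)` leaves the last (embedding) dimension of a 2-D encoder output untouched and appends n zero rows along the sequence dimension. A standalone check (the shapes are hypothetical; the real tensors come from the tokenizers/encoders):

    import torch
    import torch.nn.functional as F

    max_len = 6

    ids = torch.tensor([1, 2, 3])                # 1-D input_ids, shape (3,)
    padded_ids = F.pad(ids, (0, max_len - ids.shape[0]))
    assert padded_ids.shape == (6,)              # (0, n): n zeros appended to the last dim

    emb = torch.randn(3, 8)                      # 2-D encoder output, shape (seq, dim)
    padded_emb = F.pad(emb, (0, 0, 0, max_len - emb.shape[0]))
    assert padded_emb.shape == (6, 8)            # (0, 0, 0, n): last dim untouched,
                                                 # n zero rows appended to the seq dim

    # stacking after padding gives a uniform batch, as in the collate code above
    batch = torch.stack([padded_emb, padded_emb])
    assert batch.shape == (2, 6, 8)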
@@ -2202,6 +2201,23 @@ class FineTuningDataset(BaseDataset):
         super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)

         self.batch_size = batch_size
         self.size = min(self.width, self.height)  # the shorter side
         self.latents_cache = None

+        self.enable_bucket = enable_bucket
+        if self.enable_bucket:
+            min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
+                resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
+            )
+            self.min_bucket_reso = min_bucket_reso
+            self.max_bucket_reso = max_bucket_reso
+            self.bucket_reso_steps = bucket_reso_steps
+            self.bucket_no_upscale = bucket_no_upscale
+        else:
+            self.min_bucket_reso = None
+            self.max_bucket_reso = None
+            self.bucket_reso_steps = None  # this information is not used
+            self.bucket_no_upscale = False
+
         self.num_train_images = 0
         self.num_reg_images = 0
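`adjust_min_max_bucket_reso_by_steps` is reused from the bucketing code elsewhere in train_util.py. A hedged sketch of a plausible reading, under the assumption that it snaps the bounds onto the `bucket_reso_steps` grid and widens them if needed so the training resolution stays inside the range (the real method in train_util.py is authoritative):

    def adjust_min_max_bucket_reso_by_steps_sketch(resolution, min_reso, max_reso, steps):
        # assumption: round min down / max up to multiples of `steps`
        min_reso = (min_reso // steps) * steps
        max_reso = ((max_reso + steps - 1) // steps) * steps
        # assumption: widen the bounds so the training resolution fits between them
        min_reso = min(min_reso, min(resolution))
        max_reso = max(max_reso, max(resolution))
        return min_reso, max_reso

    print(adjust_min_max_bucket_reso_by_steps_sketch((1024, 1024), 250, 2000, 64))
    # (192, 2048)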
@@ -2221,9 +2237,25 @@ class FineTuningDataset(BaseDataset):

             # load metadata
             if os.path.exists(subset.metadata_file):
-                logger.info(f"loading existing metadata: {subset.metadata_file}")
-                with open(subset.metadata_file, "rt", encoding="utf-8") as f:
-                    metadata = json.load(f)
+                if subset.metadata_file.endswith(".jsonl"):
+                    logger.info(f"loading existing JSONL metadata: {subset.metadata_file}")
+                    # optional JSONL format
+                    # {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1", "image_size": [width, height]}
+                    metadata = {}
+                    with open(subset.metadata_file, "rt", encoding="utf-8") as f:
+                        for line in f:
+                            line_md = json.loads(line)
+                            image_md = {"caption": line_md.get("caption", "")}
+                            if "image_size" in line_md:
+                                image_md["image_size"] = line_md["image_size"]
+                            if "tags" in line_md:
+                                image_md["tags"] = line_md["tags"]
+                            metadata[line_md["image_path"]] = image_md
+                else:
+                    # standard JSON format
+                    logger.info(f"loading existing metadata: {subset.metadata_file}")
+                    with open(subset.metadata_file, "rt", encoding="utf-8") as f:
+                        metadata = json.load(f)
             else:
                 raise ValueError(f"no metadata / メタデータファイルがありません: {subset.metadata_file}")

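The JSONL branch accepts one object per line, keyed by `image_path`, and normalizes it into the same dict shape the classic JSON metadata file stores directly. A self-contained round trip (the file contents are made up for illustration):

    import json

    jsonl_text = (
        '{"image_path": "/data/img1.jpg", "caption": "a red car", "image_size": [640, 480]}\n'
        '{"image_path": "/data/img2.jpg", "caption": "a cat", "tags": "animal, indoor"}\n'
    )

    # the JSONL loader above builds the same dict the classic JSON format stores directly
    metadata = {}
    for line in jsonl_text.splitlines():
        line_md = json.loads(line)
        image_md = {"caption": line_md.get("caption", "")}
        if "image_size" in line_md:
            image_md["image_size"] = line_md["image_size"]
        if "tags" in line_md:
            image_md["tags"] = line_md["tags"]
        metadata[line_md["image_path"]] = image_md

    assert metadata["/data/img1.jpg"]["image_size"] == [640, 480]
    assert metadata["/data/img2.jpg"]["tags"] == "animal, indoor"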
@@ -2233,65 +2265,101 @@ class FineTuningDataset(BaseDataset):
                 )
                 continue

-            tags_list = []
-            for image_key, img_md in metadata.items():
-                # build path info
-                abs_path = None
-
-                # look for the image itself first
-                if os.path.exists(image_key):
-                    abs_path = image_key
+            # Add full path for image
+            image_dirs = set()
+            if subset.image_dir is not None:
+                image_dirs.add(subset.image_dir)
+            for image_key in metadata.keys():
+                if not os.path.isabs(image_key):
+                    assert (
+                        subset.image_dir is not None
+                    ), f"image_dir is required when image paths are relative / 画像パスが相対パスの場合、image_dirの指定が必要です: {image_key}"
+                    abs_path = os.path.join(subset.image_dir, image_key)
                 else:
-                    # rather ad hoc, but no better way comes to mind
-                    paths = glob_images(subset.image_dir, image_key)
-                    if len(paths) > 0:
-                        abs_path = paths[0]
+                    abs_path = image_key
+                    image_dirs.add(os.path.dirname(abs_path))
+                metadata[image_key]["abs_path"] = abs_path

-                # if not found, look for an npz file
-                if abs_path is None:
-                    if os.path.exists(os.path.splitext(image_key)[0] + ".npz"):
-                        abs_path = os.path.splitext(image_key)[0] + ".npz"
-                    else:
-                        npz_path = os.path.join(subset.image_dir, image_key + ".npz")
-                        if os.path.exists(npz_path):
-                            abs_path = npz_path
+            # Enumerate existing npz files
+            strategy = LatentsCachingStrategy.get_strategy()
+            npz_paths = []
+            for image_dir in image_dirs:
+                npz_paths.extend(glob.glob(os.path.join(image_dir, "*" + strategy.cache_suffix)))
+            npz_paths = sorted(npz_paths, key=lambda item: len(os.path.basename(item)), reverse=True)  # longer paths first

-                assert abs_path is not None, f"no image / 画像がありません: {image_key}"
+            # Match image filename longer to shorter because some images share same prefix
+            image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True)

+            # Collect tags and sizes
+            tags_list = []
+            size_set_from_metadata = 0
+            size_set_from_cache_filename = 0
+            for image_key in image_keys_sorted_by_length_desc:
+                img_md = metadata[image_key]
                 caption = img_md.get("caption")
                 tags = img_md.get("tags")
+                image_size = img_md.get("image_size")
+                abs_path = img_md.get("abs_path")
+
+                # search npz if image_size is not given
+                npz_path = None
+                if image_size is None:
+                    image_without_ext = os.path.splitext(image_key)[0]
+                    for candidate in npz_paths:
+                        if candidate.startswith(image_without_ext):
+                            npz_path = candidate
+                            break
+                    if npz_path is not None:
+                        npz_paths.remove(npz_path)  # remove to avoid matching same file (share prefix)
+                        abs_path = npz_path
+
                 if caption is None:
-                    caption = tags  # could be multiline
-                    tags = None
+                    caption = ""

                 if subset.enable_wildcard:
-                    # tags must be single line
+                    # tags must be single line (split by caption separator)
                     if tags is not None:
                         tags = tags.replace("\n", subset.caption_separator)

                     # add tags to each line of caption
-                    if caption is not None and tags is not None:
+                    if tags is not None:
                         caption = "\n".join(
                             [f"{line}{subset.caption_separator}{tags}" for line in caption.split("\n") if line.strip() != ""]
                         )
+                        tags_list.append(tags)
                 else:
                     # use as is
                     if tags is not None and len(tags) > 0:
-                        caption = caption + subset.caption_separator + tags
+                        if len(caption) > 0:
+                            caption = caption + subset.caption_separator
+                        caption = caption + tags
                         tags_list.append(tags)

-                if caption is None:
-                    caption = ""
-
                 image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
-                image_info.image_size = img_md.get("train_resolution")
                 image_info.resize_interpolation = (
                     subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
                 )

-                if not subset.color_aug and not subset.random_crop:
-                    # if npz exists, use them
-                    image_info.latents_npz, image_info.latents_npz_flipped = self.image_key_to_npz_file(subset, image_key)
+                if image_size is not None:
+                    image_info.image_size = tuple(image_size)  # width, height
+                    size_set_from_metadata += 1
+                elif npz_path is not None:
+                    # get image size from npz filename
+                    w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path)
+                    image_info.image_size = (w, h)
+                    size_set_from_cache_filename += 1

                 self.register_image(image_info, subset)

+            if size_set_from_cache_filename > 0:
+                logger.info(
+                    f"set image size from cache files: {size_set_from_cache_filename}/{len(image_keys_sorted_by_length_desc)}"
+                )
+            if size_set_from_metadata > 0:
+                logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
             self.num_train_images += len(metadata) * subset.num_repeats

+            # TODO do not record tag freq when no tag
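The longest-first sorting in this hunk is load-bearing: cache files are matched by `startswith` on the image stem, so with naive ordering `img_1` would happily claim `img_10`'s npz. Matching longer names first and consuming each matched path avoids that. A minimal demonstration (the size-suffixed cache names are illustrative, not the strategy's exact naming):

    import os

    npz_paths = ["/data/img_10_0640x0480.npz", "/data/img_1_0512x0512.npz"]
    npz_paths.sort(key=lambda p: len(os.path.basename(p)), reverse=True)  # longer first

    image_keys = sorted(["/data/img_1.jpg", "/data/img_10.jpg"], key=len, reverse=True)
    for image_key in image_keys:
        stem = os.path.splitext(image_key)[0]
        match = next((p for p in npz_paths if p.startswith(stem)), None)
        if match is not None:
            npz_paths.remove(match)  # consume it so a shorter key cannot re-match it
        print(image_key, "->", match)
    # /data/img_10.jpg -> /data/img_10_0640x0480.npz
    # /data/img_1.jpg -> /data/img_1_0512x0512.npz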
@@ -2299,117 +2367,6 @@ class FineTuningDataset(BaseDataset):
             subset.img_count = len(metadata)
             self.subsets.append(subset)

-        # check existence of all npz files
-        use_npz_latents = all([not (subset.color_aug or subset.random_crop) for subset in self.subsets])
-        if use_npz_latents:
-            flip_aug_in_subset = False
-            npz_any = False
-            npz_all = True
-
-            for image_info in self.image_data.values():
-                subset = self.image_to_subset[image_info.image_key]
-
-                has_npz = image_info.latents_npz is not None
-                npz_any = npz_any or has_npz
-
-                if subset.flip_aug:
-                    has_npz = has_npz and image_info.latents_npz_flipped is not None
-                    flip_aug_in_subset = True
-                npz_all = npz_all and has_npz
-
-                if npz_any and not npz_all:
-                    break
-
-            if not npz_any:
-                use_npz_latents = False
-                logger.warning(f"npz file does not exist. ignore npz files / npzファイルが見つからないためnpzファイルを無視します")
-            elif not npz_all:
-                use_npz_latents = False
-                logger.warning(
-                    f"some of npz file does not exist. ignore npz files / いくつかのnpzファイルが見つからないためnpzファイルを無視します"
-                )
-                if flip_aug_in_subset:
-                    logger.warning("maybe no flipped files / 反転されたnpzファイルがないのかもしれません")
-        # else:
-        #     logger.info("npz files are not used with color_aug and/or random_crop / color_augまたはrandom_cropが指定されているためnpzファイルは使用されません")
-
-        # check min/max bucket size
-        sizes = set()
-        resos = set()
-        for image_info in self.image_data.values():
-            if image_info.image_size is None:
-                sizes = None  # not calculated
-                break
-            sizes.add(image_info.image_size[0])
-            sizes.add(image_info.image_size[1])
-            resos.add(tuple(image_info.image_size))
-
-        if sizes is None:
-            if use_npz_latents:
-                use_npz_latents = False
-                logger.warning(
-                    f"npz files exist, but no bucket info in metadata. ignore npz files / メタデータにbucket情報がないためnpzファイルを無視します"
-                )
-
-            assert (
-                resolution is not None
-            ), "if metadata doesn't have bucket info, resolution is required / メタデータにbucket情報がない場合はresolutionを指定してください"
-
-            self.enable_bucket = enable_bucket
-            if self.enable_bucket:
-                min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
-                    resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
-                )
-                self.min_bucket_reso = min_bucket_reso
-                self.max_bucket_reso = max_bucket_reso
-                self.bucket_reso_steps = bucket_reso_steps
-                self.bucket_no_upscale = bucket_no_upscale
-        else:
-            if not enable_bucket:
-                logger.info("metadata has bucket info, enable bucketing / メタデータにbucket情報があるためbucketを有効にします")
-            logger.info("using bucket info in metadata / メタデータ内のbucket情報を使います")
-            self.enable_bucket = True
-
-            assert (
-                not bucket_no_upscale
-            ), "if metadata has bucket info, bucket reso is precalculated, so bucket_no_upscale cannot be used / メタデータ内にbucket情報がある場合はbucketの解像度は計算済みのため、bucket_no_upscaleは使えません"
-
-            # initialize bucket info here; it is not recreated in make_buckets
-            self.bucket_manager = BucketManager(False, None, None, None, None)
-            self.bucket_manager.set_predefined_resos(resos)
-
-        # clean up npz info
-        if not use_npz_latents:
-            for image_info in self.image_data.values():
-                image_info.latents_npz = image_info.latents_npz_flipped = None
-
-    def image_key_to_npz_file(self, subset: FineTuningSubset, image_key):
-        base_name = os.path.splitext(image_key)[0]
-        npz_file_norm = base_name + ".npz"
-
-        if os.path.exists(npz_file_norm):
-            # image_key is full path
-            npz_file_flip = base_name + "_flip.npz"
-            if not os.path.exists(npz_file_flip):
-                npz_file_flip = None
-            return npz_file_norm, npz_file_flip
-
-        # if not full path, check image_dir. if image_dir is None, return None
-        if subset.image_dir is None:
-            return None, None
-
-        # image_key is relative path
-        npz_file_norm = os.path.join(subset.image_dir, image_key + ".npz")
-        npz_file_flip = os.path.join(subset.image_dir, image_key + "_flip.npz")
-
-        if not os.path.exists(npz_file_norm):
-            npz_file_norm = None
-            npz_file_flip = None
-        elif not os.path.exists(npz_file_flip):
-            npz_file_flip = None
-
-        return npz_file_norm, npz_file_flip
-

 class ControlNetDataset(BaseDataset):
     def __init__(