mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
fix: improve image path handling and memory management in dataset classes
This commit is contained in:
@@ -1131,7 +1131,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __eq__(self, other):
|
||||
return (
|
||||
self.reso == other.reso
|
||||
other is not None
|
||||
and self.reso == other.reso
|
||||
and self.flip_aug == other.flip_aug
|
||||
and self.alpha_mask == other.alpha_mask
|
||||
and self.random_crop == other.random_crop
|
||||
@@ -1193,6 +1194,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) > 0 and current_condition != condition:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
if condition != current_condition and HIGH_VRAM: # even with high VRAM, if shape is changed
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if info.image is None:
|
||||
# load image in parallel
|
||||
@@ -1205,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) >= caching_strategy.batch_size:
|
||||
submit_batch(batch, current_condition)
|
||||
batch = []
|
||||
current_condition = None
|
||||
# current_condition = None
|
||||
|
||||
if len(batch) > 0:
|
||||
submit_batch(batch, current_condition)
|
||||
@@ -2234,49 +2237,53 @@ class FineTuningDataset(BaseDataset):
|
||||
)
|
||||
continue
|
||||
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix))
|
||||
npz_paths = [os.path.basename(x) for x in npz_paths]
|
||||
npz_paths = sorted(npz_paths, key=len, reverse=True) # make longer paths come first to speed up matching
|
||||
# Add full path for image
|
||||
image_dirs = set()
|
||||
if subset.image_dir is not None:
|
||||
image_dirs.add(subset.image_dir)
|
||||
for image_key in metadata.keys():
|
||||
if not os.path.isabs(image_key):
|
||||
assert (
|
||||
subset.image_dir is not None
|
||||
), f"image_dir is required when image paths are relative / 画像パスが相対パスの場合、image_dirの指定が必要です: {image_key}"
|
||||
abs_path = os.path.join(subset.image_dir, image_key)
|
||||
else:
|
||||
abs_path = image_key
|
||||
image_dirs.add(os.path.dirname(abs_path))
|
||||
metadata[image_key]["abs_path"] = abs_path
|
||||
|
||||
tags_list = []
|
||||
# Enumerate existing npz files
|
||||
strategy = LatentsCachingStrategy.get_strategy()
|
||||
npz_paths = []
|
||||
for image_dir in image_dirs:
|
||||
npz_paths.extend(glob.glob(os.path.join(image_dir, "*" + strategy.cache_suffix)))
|
||||
npz_paths = sorted(npz_paths, key=lambda item: len(os.path.basename(item)), reverse=True) # longer paths first
|
||||
|
||||
# Match image filename longer to shorter because some images share same prefix
|
||||
image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True)
|
||||
|
||||
size_set_count = 0
|
||||
# Collect tags and sizes
|
||||
tags_list = []
|
||||
size_set_from_metadata = 0
|
||||
size_set_from_cache_filename = 0
|
||||
for image_key in image_keys_sorted_by_length_desc:
|
||||
img_md = metadata[image_key]
|
||||
|
||||
# make absolute path for image or npz
|
||||
abs_path, npz_path = None, None
|
||||
|
||||
# full path for image?
|
||||
image_rel_key = image_key
|
||||
if os.path.exists(image_key):
|
||||
image_rel_key = os.path.basename(image_key)
|
||||
abs_path = image_key
|
||||
else:
|
||||
# relative path without extension
|
||||
paths = glob_images(subset.image_dir, image_key)
|
||||
if len(paths) > 0:
|
||||
abs_path = paths[0]
|
||||
|
||||
# search npz
|
||||
npz_path = None
|
||||
for candidate in npz_paths:
|
||||
if candidate.startswith(image_rel_key):
|
||||
npz_path = candidate
|
||||
break
|
||||
if npz_path is not None:
|
||||
npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix)
|
||||
abs_path = abs_path or npz_path
|
||||
|
||||
assert abs_path is not None, f"no image / 画像がありません: {image_key}"
|
||||
|
||||
caption = img_md.get("caption")
|
||||
tags = img_md.get("tags")
|
||||
image_size = img_md.get("image_size")
|
||||
abs_path = img_md.get("abs_path")
|
||||
|
||||
# search npz if image_size is not given
|
||||
npz_path = None
|
||||
if image_size is None:
|
||||
image_without_ext = os.path.splitext(image_key)[0]
|
||||
for candidate in npz_paths:
|
||||
if candidate.startswith(image_without_ext):
|
||||
npz_path = candidate
|
||||
break
|
||||
if npz_path is not None:
|
||||
npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix)
|
||||
abs_path = npz_path
|
||||
|
||||
if caption is None:
|
||||
caption = ""
|
||||
@@ -2310,16 +2317,21 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
if image_size is not None:
|
||||
image_info.image_size = tuple(image_size) # width, height
|
||||
elif npz_path is not None and strategy is not None:
|
||||
size_set_from_metadata += 1
|
||||
elif npz_path is not None:
|
||||
# get image size from npz filename
|
||||
w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path)
|
||||
image_info.image_size = (w, h)
|
||||
size_set_count += 1
|
||||
size_set_from_cache_filename += 1
|
||||
|
||||
self.register_image(image_info, subset)
|
||||
|
||||
if size_set_count > 0:
|
||||
logger.info(f"set image size from cache files: {size_set_count}/{len(image_keys_sorted_by_length_desc)}")
|
||||
if size_set_from_cache_filename > 0:
|
||||
logger.info(
|
||||
f"set image size from cache files: {size_set_from_cache_filename}/{len(image_keys_sorted_by_length_desc)}"
|
||||
)
|
||||
if size_set_from_metadata > 0:
|
||||
logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}")
|
||||
self.num_train_images += len(metadata) * subset.num_repeats
|
||||
|
||||
# TODO do not record tag freq when no tag
|
||||
|
||||
Reference in New Issue
Block a user