From f3d5b063376ea554eb0f8a21977990a43eb6dedc Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Wed, 3 Sep 2025 22:00:20 +0900 Subject: [PATCH] fix: improve image path handling and memory management in dataset classes --- library/train_util.py | 90 ++++++++++++++++++++++++------------------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 29b61bf3..131cb612 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1131,7 +1131,8 @@ class BaseDataset(torch.utils.data.Dataset): def __eq__(self, other): return ( - self.reso == other.reso + other is not None + and self.reso == other.reso and self.flip_aug == other.flip_aug and self.alpha_mask == other.alpha_mask and self.random_crop == other.random_crop @@ -1193,6 +1194,8 @@ class BaseDataset(torch.utils.data.Dataset): if len(batch) > 0 and current_condition != condition: submit_batch(batch, current_condition) batch = [] + if condition != current_condition and HIGH_VRAM: # even with high VRAM, if shape is changed + clean_memory_on_device(accelerator.device) if info.image is None: # load image in parallel @@ -1205,7 +1208,7 @@ class BaseDataset(torch.utils.data.Dataset): if len(batch) >= caching_strategy.batch_size: submit_batch(batch, current_condition) batch = [] - current_condition = None + # current_condition = None if len(batch) > 0: submit_batch(batch, current_condition) @@ -2234,49 +2237,53 @@ class FineTuningDataset(BaseDataset): ) continue - strategy = LatentsCachingStrategy.get_strategy() - npz_paths = glob.glob(os.path.join(subset.image_dir, "*" + strategy.cache_suffix)) - npz_paths = [os.path.basename(x) for x in npz_paths] - npz_paths = sorted(npz_paths, key=len, reverse=True) # make longer paths come first to speed up matching + # Add full path for image + image_dirs = set() + if subset.image_dir is not None: + image_dirs.add(subset.image_dir) + for image_key in metadata.keys(): + if not os.path.isabs(image_key): + assert ( + subset.image_dir is not None + ), f"image_dir is required when image paths are relative / 画像パスが相対パスの場合、image_dirの指定が必要です: {image_key}" + abs_path = os.path.join(subset.image_dir, image_key) + else: + abs_path = image_key + image_dirs.add(os.path.dirname(abs_path)) + metadata[image_key]["abs_path"] = abs_path - tags_list = [] + # Enumerate existing npz files + strategy = LatentsCachingStrategy.get_strategy() + npz_paths = [] + for image_dir in image_dirs: + npz_paths.extend(glob.glob(os.path.join(image_dir, "*" + strategy.cache_suffix))) + npz_paths = sorted(npz_paths, key=lambda item: len(os.path.basename(item)), reverse=True) # longer paths first # Match image filename longer to shorter because some images share same prefix image_keys_sorted_by_length_desc = sorted(metadata.keys(), key=len, reverse=True) - size_set_count = 0 + # Collect tags and sizes + tags_list = [] + size_set_from_metadata = 0 + size_set_from_cache_filename = 0 for image_key in image_keys_sorted_by_length_desc: img_md = metadata[image_key] - - # make absolute path for image or npz - abs_path, npz_path = None, None - - # full path for image? - image_rel_key = image_key - if os.path.exists(image_key): - image_rel_key = os.path.basename(image_key) - abs_path = image_key - else: - # relative path without extension - paths = glob_images(subset.image_dir, image_key) - if len(paths) > 0: - abs_path = paths[0] - - # search npz - npz_path = None - for candidate in npz_paths: - if candidate.startswith(image_rel_key): - npz_path = candidate - break - if npz_path is not None: - npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix) - abs_path = abs_path or npz_path - - assert abs_path is not None, f"no image / 画像がありません: {image_key}" - caption = img_md.get("caption") tags = img_md.get("tags") image_size = img_md.get("image_size") + abs_path = img_md.get("abs_path") + + # search npz if image_size is not given + npz_path = None + if image_size is None: + image_without_ext = os.path.splitext(image_key)[0] + for candidate in npz_paths: + if candidate.startswith(image_without_ext): + npz_path = candidate + break + if npz_path is not None: + npz_paths.remove(npz_path) # remove to avoid matching same file (share prefix) + abs_path = npz_path if caption is None: caption = "" @@ -2310,16 +2317,21 @@ class FineTuningDataset(BaseDataset): if image_size is not None: image_info.image_size = tuple(image_size) # width, height - elif npz_path is not None and strategy is not None: + size_set_from_metadata += 1 + elif npz_path is not None: # get image size from npz filename w, h = strategy.get_image_size_from_disk_cache_path(abs_path, npz_path) image_info.image_size = (w, h) - size_set_count += 1 + size_set_from_cache_filename += 1 self.register_image(image_info, subset) - if size_set_count > 0: - logger.info(f"set image size from cache files: {size_set_count}/{len(image_keys_sorted_by_length_desc)}") + if size_set_from_cache_filename > 0: + logger.info( + f"set image size from cache files: {size_set_from_cache_filename}/{len(image_keys_sorted_by_length_desc)}" + ) + if size_set_from_metadata > 0: + logger.info(f"set image size from metadata: {size_set_from_metadata}/{len(image_keys_sorted_by_length_desc)}") self.num_train_images += len(metadata) * subset.num_repeats # TODO do not record tag freq when no tag