From ab0d534bc53cf837ee167624bdddc61b46d1e702 Mon Sep 17 00:00:00 2001 From: woctordho Date: Mon, 16 Feb 2026 01:58:15 +0800 Subject: [PATCH] Multi-resolution dataset for SD1/SDXL --- library/strategy_sd.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/library/strategy_sd.py b/library/strategy_sd.py index a44fc409..48808db1 100644 --- a/library/strategy_sd.py +++ b/library/strategy_sd.py @@ -2,6 +2,7 @@ import glob import os from typing import Any, List, Optional, Tuple, Union +import numpy as np import torch from transformers import CLIPTokenizer from library import train_util @@ -157,7 +158,12 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): return os.path.splitext(absolute_path)[0] + f"_{image_size[0]:04d}x{image_size[1]:04d}" + self.suffix def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool): - return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask) + return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True) + + def load_latents_from_disk( + self, npz_path: str, bucket_reso: Tuple[int, int] + ) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: + return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # TODO remove circular dependency for ImageInfo def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool): @@ -165,7 +171,7 @@ class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy): vae_device = vae.device vae_dtype = vae.dtype - self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop) + self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, multi_resolution=True) if not train_util.HIGH_VRAM: train_util.clean_memory_on_device(vae.device)