mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
feat: Support multi-resolution training with caching latents to disk
This commit is contained in:
11
README.md
11
README.md
@@ -9,13 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
|||||||
The command to install PyTorch is as follows:
|
The command to install PyTorch is as follows:
|
||||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||||
|
|
||||||
|
Aug 20, 2024 (update 3):
|
||||||
|
__Experimental__ The multi-resolution training is now supported with caching latents to disk.
|
||||||
|
|
||||||
|
The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file).
|
||||||
|
|
||||||
|
See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
|
||||||
|
|
||||||
Aug 20, 2024 (update 2):
|
Aug 20, 2024 (update 2):
|
||||||
`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015!
|
`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015!
|
||||||
|
|
||||||
Aug 20, 2024:
|
Aug 20, 2024:
|
||||||
FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution).
|
FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution).
|
||||||
|
|
||||||
The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
|
The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
|
||||||
|
|
||||||
We will support multi-resolution caching to disk in the near future.
|
We will support multi-resolution caching to disk in the near future.
|
||||||
|
|
||||||
@@ -171,7 +178,7 @@ The script can merge multiple LoRA models. If you want to merge multiple LoRA mo
|
|||||||
|
|
||||||
### FLUX.1 Multi-resolution training
|
### FLUX.1 Multi-resolution training
|
||||||
|
|
||||||
You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__
|
You can define multiple resolutions in the dataset configuration file.
|
||||||
|
|
||||||
The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution.
|
The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution.
|
||||||
|
|
||||||
|
|||||||
@@ -219,7 +219,13 @@ class LatentsCachingStrategy:
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _default_is_disk_cached_latents_expected(
|
def _default_is_disk_cached_latents_expected(
|
||||||
self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
|
self,
|
||||||
|
latents_stride: int,
|
||||||
|
bucket_reso: Tuple[int, int],
|
||||||
|
npz_path: str,
|
||||||
|
flip_aug: bool,
|
||||||
|
alpha_mask: bool,
|
||||||
|
multi_resolution: bool = False,
|
||||||
):
|
):
|
||||||
if not self.cache_to_disk:
|
if not self.cache_to_disk:
|
||||||
return False
|
return False
|
||||||
@@ -230,25 +236,17 @@ class LatentsCachingStrategy:
|
|||||||
|
|
||||||
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||||
|
|
||||||
|
# e.g. "_32x64", HxW
|
||||||
|
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
npz = np.load(npz_path)
|
npz = np.load(npz_path)
|
||||||
if npz["latents"].shape[1:3] != expected_latents_size:
|
if "latents" + key_reso_suffix not in npz:
|
||||||
|
return False
|
||||||
|
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
|
||||||
|
return False
|
||||||
|
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if flip_aug:
|
|
||||||
if "latents_flipped" not in npz:
|
|
||||||
return False
|
|
||||||
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if alpha_mask:
|
|
||||||
if "alpha_mask" not in npz:
|
|
||||||
return False
|
|
||||||
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
if "alpha_mask" in npz:
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading file: {npz_path}")
|
logger.error(f"Error loading file: {npz_path}")
|
||||||
raise e
|
raise e
|
||||||
@@ -257,7 +255,15 @@ class LatentsCachingStrategy:
|
|||||||
|
|
||||||
# TODO remove circular dependency for ImageInfo
|
# TODO remove circular dependency for ImageInfo
|
||||||
def _default_cache_batch_latents(
|
def _default_cache_batch_latents(
|
||||||
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
|
self,
|
||||||
|
encode_by_vae,
|
||||||
|
vae_device,
|
||||||
|
vae_dtype,
|
||||||
|
image_infos: List,
|
||||||
|
flip_aug: bool,
|
||||||
|
alpha_mask: bool,
|
||||||
|
random_crop: bool,
|
||||||
|
multi_resolution: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
|
||||||
@@ -287,8 +293,13 @@ class LatentsCachingStrategy:
|
|||||||
original_size = original_sizes[i]
|
original_size = original_sizes[i]
|
||||||
crop_ltrb = crop_ltrbs[i]
|
crop_ltrb = crop_ltrbs[i]
|
||||||
|
|
||||||
|
latents_size = latents.shape[1:3] # H, W
|
||||||
|
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
|
||||||
|
|
||||||
if self.cache_to_disk:
|
if self.cache_to_disk:
|
||||||
self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
|
self.save_latents_to_disk(
|
||||||
|
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
info.latents_original_size = original_size
|
info.latents_original_size = original_size
|
||||||
info.latents_crop_ltrb = crop_ltrb
|
info.latents_crop_ltrb = crop_ltrb
|
||||||
@@ -298,31 +309,56 @@ class LatentsCachingStrategy:
|
|||||||
info.alpha_mask = alpha_mask
|
info.alpha_mask = alpha_mask
|
||||||
|
|
||||||
def load_latents_from_disk(
|
def load_latents_from_disk(
|
||||||
self, npz_path: str
|
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||||
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||||
npz = np.load(npz_path)
|
"""
|
||||||
if "latents" not in npz:
|
for SD/SDXL/SD3.0
|
||||||
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
|
"""
|
||||||
|
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
|
||||||
|
|
||||||
latents = npz["latents"]
|
def _default_load_latents_from_disk(
|
||||||
original_size = npz["original_size"].tolist()
|
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
|
||||||
crop_ltrb = npz["crop_ltrb"].tolist()
|
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||||
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
|
if latents_stride is None:
|
||||||
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
|
key_reso_suffix = ""
|
||||||
|
else:
|
||||||
|
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
|
||||||
|
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
|
||||||
|
|
||||||
|
npz = np.load(npz_path)
|
||||||
|
if "latents" + key_reso_suffix not in npz:
|
||||||
|
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
|
||||||
|
|
||||||
|
latents = npz["latents" + key_reso_suffix]
|
||||||
|
original_size = npz["original_size" + key_reso_suffix].tolist()
|
||||||
|
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
|
||||||
|
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
|
||||||
|
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
|
||||||
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
|
||||||
|
|
||||||
def save_latents_to_disk(
|
def save_latents_to_disk(
|
||||||
self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
|
self,
|
||||||
|
npz_path,
|
||||||
|
latents_tensor,
|
||||||
|
original_size,
|
||||||
|
crop_ltrb,
|
||||||
|
flipped_latents_tensor=None,
|
||||||
|
alpha_mask=None,
|
||||||
|
key_reso_suffix="",
|
||||||
):
|
):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
|
if os.path.exists(npz_path):
|
||||||
|
# load existing npz and update it
|
||||||
|
npz = np.load(npz_path)
|
||||||
|
for key in npz.files:
|
||||||
|
kwargs[key] = npz[key]
|
||||||
|
|
||||||
|
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
|
||||||
|
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
|
||||||
|
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
|
||||||
if flipped_latents_tensor is not None:
|
if flipped_latents_tensor is not None:
|
||||||
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
|
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
|
||||||
if alpha_mask is not None:
|
if alpha_mask is not None:
|
||||||
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
|
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
|
||||||
np.savez(
|
np.savez(npz_path, **kwargs)
|
||||||
npz_path,
|
|
||||||
latents=latents_tensor.float().cpu().numpy(),
|
|
||||||
original_size=np.array(original_size),
|
|
||||||
crop_ltrb=np.array(crop_ltrb),
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -200,7 +200,12 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
|
||||||
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
|
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True)
|
||||||
|
|
||||||
|
def load_latents_from_disk(
|
||||||
|
self, npz_path: str, bucket_reso: Tuple[int, int]
|
||||||
|
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||||
|
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
|
||||||
|
|
||||||
# TODO remove circular dependency for ImageInfo
|
# TODO remove circular dependency for ImageInfo
|
||||||
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
|
||||||
@@ -208,7 +213,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
|
|||||||
vae_device = vae.device
|
vae_device = vae.device
|
||||||
vae_dtype = vae.dtype
|
vae_dtype = vae.dtype
|
||||||
|
|
||||||
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
|
self._default_cache_batch_latents(
|
||||||
|
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True
|
||||||
|
)
|
||||||
|
|
||||||
if not train_util.HIGH_VRAM:
|
if not train_util.HIGH_VRAM:
|
||||||
train_util.clean_memory_on_device(vae.device)
|
train_util.clean_memory_on_device(vae.device)
|
||||||
|
|||||||
@@ -1381,7 +1381,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
|||||||
image = None
|
image = None
|
||||||
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
|
||||||
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
|
||||||
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz)
|
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
|
||||||
)
|
)
|
||||||
if flipped:
|
if flipped:
|
||||||
latents = flipped_latents
|
latents = flipped_latents
|
||||||
|
|||||||
Reference in New Issue
Block a user