feat: Support multi-resolution training with caching latents to disk

This commit is contained in:
Kohya S
2024-08-20 21:39:43 +09:00
parent 388b3b4b74
commit 6ab48b09d8
4 changed files with 93 additions and 43 deletions

View File

@@ -9,13 +9,20 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
Aug 20, 2024 (update 3):
__Experimental__ The multi-resolution training is now supported with caching latents to disk.
The cache files now hold latents for multiple resolutions. Since the latents are appended to the current cache file, it is recommended to delete the cache file in advance (if not, the old latents is kept in .npz file).
See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
Aug 20, 2024 (update 2):
`flux_merge_lora.py` now supports LoRA from AI-toolkit (Diffusers based keys). Specify `--diffusers` option to merge LoRA with Diffusers based keys. Thanks to exveria1015!
Aug 20, 2024:
FLUX.1 supports multi-resolution inference, so training at multiple resolutions may be possible and the results may be improved (like 1024x1024, 768x768 and 512x512 ... you can use any resolution).
The script seems to support multi-resolution even in the current version, __if `--cache_latents_to_disk` is not specified__. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
The script seems to support multi-resolution even in the current version, ~~if `--cache_latents_to_disk` is not specified~~ -> `--cache_latents_to_disk` is now supported for multi-resolution training. Please try if you are interested. See [FLUX.1 Multi-resolution training](#flux1-multi-resolution-training) for details.
We will support multi-resolution caching to disk in the near future.
@@ -171,7 +178,7 @@ The script can merge multiple LoRA models. If you want to merge multiple LoRA mo
### FLUX.1 Multi-resolution training
You can define multiple resolutions in the dataset configuration file. __Caching latents to disk is not supported yet.__
You can define multiple resolutions in the dataset configuration file.
The dataset configuration file is like below. You can define multiple resolutions with different batch sizes. The resolutions are defined in the `[[datasets]]` section. The `[[datasets.subsets]]` section is for the dataset directory. Please specify the same directory for each resolution.

View File

@@ -219,7 +219,13 @@ class LatentsCachingStrategy:
raise NotImplementedError
def _default_is_disk_cached_latents_expected(
self, latents_stride: int, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool
self,
latents_stride: int,
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
multi_resolution: bool = False,
):
if not self.cache_to_disk:
return False
@@ -230,24 +236,16 @@ class LatentsCachingStrategy:
expected_latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
# e.g. "_32x64", HxW
key_reso_suffix = f"_{expected_latents_size[0]}x{expected_latents_size[1]}" if multi_resolution else ""
try:
npz = np.load(npz_path)
if npz["latents"].shape[1:3] != expected_latents_size:
if "latents" + key_reso_suffix not in npz:
return False
if flip_aug:
if "latents_flipped" not in npz:
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
if alpha_mask:
if "alpha_mask" not in npz:
return False
if npz["alpha_mask"].shape[0:2] != (bucket_reso[1], bucket_reso[0]):
return False
else:
if "alpha_mask" in npz:
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
@@ -257,7 +255,15 @@ class LatentsCachingStrategy:
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self, encode_by_vae, vae_device, vae_dtype, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool
self,
encode_by_vae,
vae_device,
vae_dtype,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.
@@ -287,8 +293,13 @@ class LatentsCachingStrategy:
original_size = original_sizes[i]
crop_ltrb = crop_ltrbs[i]
latents_size = latents.shape[1:3] # H, W
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" if multi_resolution else "" # e.g. "_32x64", HxW
if self.cache_to_disk:
self.save_latents_to_disk(info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask)
self.save_latents_to_disk(
info.latents_npz, latents, original_size, crop_ltrb, flipped_latent, alpha_mask, key_reso_suffix
)
else:
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -298,31 +309,56 @@ class LatentsCachingStrategy:
info.alpha_mask = alpha_mask
def load_latents_from_disk(
self, npz_path: str
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
npz = np.load(npz_path)
if "latents" not in npz:
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
"""
for SD/SDXL/SD3.0
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_ltrb = npz["crop_ltrb"].tolist()
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None
def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
if latents_stride is None:
key_reso_suffix = ""
else:
latents_size = (bucket_reso[1] // latents_stride, bucket_reso[0] // latents_stride) # bucket_reso is (W, H)
key_reso_suffix = f"_{latents_size[0]}x{latents_size[1]}" # e.g. "_32x64", HxW
npz = np.load(npz_path)
if "latents" + key_reso_suffix not in npz:
raise ValueError(f"latents{key_reso_suffix} not found in {npz_path}")
latents = npz["latents" + key_reso_suffix]
original_size = npz["original_size" + key_reso_suffix].tolist()
crop_ltrb = npz["crop_ltrb" + key_reso_suffix].tolist()
flipped_latents = npz["latents_flipped" + key_reso_suffix] if "latents_flipped" + key_reso_suffix in npz else None
alpha_mask = npz["alpha_mask" + key_reso_suffix] if "alpha_mask" + key_reso_suffix in npz else None
return latents, original_size, crop_ltrb, flipped_latents, alpha_mask
def save_latents_to_disk(
self, npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None
self,
npz_path,
latents_tensor,
original_size,
crop_ltrb,
flipped_latents_tensor=None,
alpha_mask=None,
key_reso_suffix="",
):
kwargs = {}
if os.path.exists(npz_path):
# load existing npz and update it
npz = np.load(npz_path)
for key in npz.files:
kwargs[key] = npz[key]
kwargs["latents" + key_reso_suffix] = latents_tensor.float().cpu().numpy()
kwargs["original_size" + key_reso_suffix] = np.array(original_size)
kwargs["crop_ltrb" + key_reso_suffix] = np.array(crop_ltrb)
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
kwargs["latents_flipped" + key_reso_suffix] = flipped_latents_tensor.float().cpu().numpy()
if alpha_mask is not None:
kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy()
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
original_size=np.array(original_size),
crop_ltrb=np.array(crop_ltrb),
**kwargs,
)
kwargs["alpha_mask" + key_reso_suffix] = alpha_mask.float().cpu().numpy()
np.savez(npz_path, **kwargs)

View File

@@ -200,7 +200,12 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
)
def is_disk_cached_latents_expected(self, bucket_reso: Tuple[int, int], npz_path: str, flip_aug: bool, alpha_mask: bool):
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask)
return self._default_is_disk_cached_latents_expected(8, bucket_reso, npz_path, flip_aug, alpha_mask, True)
def load_latents_from_disk(
self, npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
return self._default_load_latents_from_disk(8, npz_path, bucket_reso) # support multi-resolution
# TODO remove circular dependency for ImageInfo
def cache_batch_latents(self, vae, image_infos: List, flip_aug: bool, alpha_mask: bool, random_crop: bool):
@@ -208,7 +213,9 @@ class FluxLatentsCachingStrategy(LatentsCachingStrategy):
vae_device = vae.device
vae_dtype = vae.dtype
self._default_cache_batch_latents(encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop)
self._default_cache_batch_latents(
encode_by_vae, vae_device, vae_dtype, image_infos, flip_aug, alpha_mask, random_crop, True
)
if not train_util.HIGH_VRAM:
train_util.clean_memory_on_device(vae.device)

View File

@@ -1381,7 +1381,7 @@ class BaseDataset(torch.utils.data.Dataset):
image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents, alpha_mask = (
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz)
self.latents_caching_strategy.load_latents_from_disk(image_info.latents_npz, image_info.bucket_reso)
)
if flipped:
latents = flipped_latents