refactor caching latents (flip in same npz, etc)

This commit is contained in:
Kohya S
2023-07-15 18:28:33 +09:00
parent 81fa54837f
commit 94c151aea3
3 changed files with 409 additions and 239 deletions

View File

@@ -50,6 +50,7 @@ from diffusers import (
HeunDiscreteScheduler,
KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
AutoencoderKL,
)
from library import custom_train_functions
from library.original_unet import UNet2DConditionModel
@@ -96,6 +97,13 @@ try:
except:
pass
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -110,10 +118,10 @@ class ImageInfo:
self.latents: torch.Tensor = None
self.latents_flipped: torch.Tensor = None
self.latents_npz: str = None
self.latents_npz_flipped: str = None
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
self.cond_img_path: str = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
class BucketManager:
@@ -507,21 +515,22 @@ class BaseDataset(torch.utils.data.Dataset):
# augmentation
self.aug_helper = AugHelper()
self.image_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.image_transforms = IMAGE_TRANSFORMS
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
self.replacements = {}
# caching
self.caching_mode = None # None, 'latents', 'text'
def set_seed(self, seed):
self.seed = seed
def set_caching_mode(self, mode):
self.caching_mode = mode
def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets()
@@ -767,45 +776,6 @@ class BaseDataset(torch.utils.data.Dataset):
random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle()
def load_image(self, image_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
return img
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top)
def trim_and_resize_if_required(
self, subset: BaseSubset, image, reso, resized_size
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
image_height, image_width = image.shape[0:2]
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height)
crop_left_top = (0, 0)
if image_width > reso[0]:
trim_size = image_width - reso[0]
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
# print("w", trim_size, p)
image = image[:, p : p + reso[0]]
crop_left_top = (p, 0)
if image_height > reso[1]:
trim_size = image_height - reso[1]
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
# print("h", trim_size, p)
image = image[p : p + reso[1]]
crop_left_top = (crop_left_top[0], p)
assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
return image, original_size, crop_left_top
def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
@@ -822,26 +792,6 @@ class BaseDataset(torch.utils.data.Dataset):
]
)
def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
for npath in [npz_path, flipped_npz_path]:
if npath is None:
continue
if not os.path.exists(npath):
return False
npz = np.load(npath)
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
return False
cached_latents = npz["latents"]
if cached_latents.shape[1:3] != expected_latents_size:
return False
return True
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# ちょっと速くした
print("caching latents.")
@@ -864,13 +814,10 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents
if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
if not is_main_process:
if not is_main_process: # store to info only
continue
cache_available = self.is_disk_cached_latents_is_expected(
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None
)
cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
if cache_available: # do not add to batch
continue
@@ -890,60 +837,19 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) > 0:
batches.append(batch)
if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only
if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
return
# iterate batches
# iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = []
for info in batch:
image = self.load_image(info.absolute_path)
image, original_size, crop_left_top = self.trim_and_resize_if_required(
subset, image, info.bucket_reso, info.resized_size
)
image = self.image_transforms(image)
images.append(image)
info.latents_original_size = original_size
info.latents_crop_left_top = crop_left_top
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
# check NaN
if torch.isnan(latents).any():
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top)
else:
info.latents = latent
if subset.flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3])
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
# check NaN
if torch.isnan(latents).any():
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
# crop_left_top is reversed when making batch
save_latents_to_disk(
info.latents_npz_flipped, latent, info.latents_original_size, info.latents_crop_left_top
)
else:
info.latents_flipped = latent
cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
def get_image_size(self, image_path):
image = Image.open(image_path)
return image.size
def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
img = self.load_image(image_path)
img = load_image(image_path)
face_cx = face_cy = face_w = face_h = 0
if subset.face_crop_aug_range is not None:
@@ -1004,10 +910,6 @@ class BaseDataset(torch.utils.data.Dataset):
return image
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
return load_latents_from_disk(npz_file)
def __len__(self):
return self._length
@@ -1016,6 +918,9 @@ class BaseDataset(torch.utils.data.Dataset):
bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size
if self.caching_mode is not None: # return batch for latents/text encoder outputs caching
return self.get_item_for_caching(bucket, bucket_batch_size, image_index)
loss_weights = []
captions = []
input_ids_list = []
@@ -1045,7 +950,10 @@ class BaseDataset(torch.utils.data.Dataset):
image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_left_top = self.load_latents_from_npz(image_info, flipped)
latents, original_size, crop_left_top, flipped_latents = load_latents_from_disk(image_info.latents_npz)
if flipped:
latents = flipped_latents
del flipped_latents
latents = torch.FloatTensor(latents)
image = None
@@ -1055,8 +963,8 @@ class BaseDataset(torch.utils.data.Dataset):
im_h, im_w = img.shape[0:2]
if self.enable_bucket:
img, original_size, crop_left_top = self.trim_and_resize_if_required(
subset, img, image_info.bucket_reso, image_info.resized_size
img, original_size, crop_left_top = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
)
else:
if face_cx > 0: # 顔位置情報あり
@@ -1162,6 +1070,53 @@ class BaseDataset(torch.utils.data.Dataset):
example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example
def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
captions = []
images = []
absolute_paths = []
resized_sizes = []
bucket_reso = None
flip_aug = None
random_crop = None
for image_key in bucket[image_index : image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
subset = self.image_to_subset[image_key]
if flip_aug is None:
flip_aug = subset.flip_aug
random_crop = subset.random_crop
bucket_reso = image_info.bucket_reso
else:
assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
assert random_crop == subset.random_crop, "random_crop must be same in a batch"
assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
caption = image_info.caption # TODO cache some patterns of droping, shuffling, etc.
if self.caching_mode == "latents":
image = load_image(image_info.absolute_path)
else:
image = None
captions.append(caption)
images.append(image)
absolute_paths.append(image_info.absolute_path)
resized_sizes.append(image_info.resized_size)
example = {}
if images[0] is None:
images = None
example["images"] = images
example["captions"] = captions
example["absolute_paths"] = absolute_paths
example["resized_sizes"] = resized_sizes
example["flip_aug"] = flip_aug
example["random_crop"] = random_crop
example["bucket_reso"] = bucket_reso
return example
class DreamBoothDataset(BaseDataset):
def __init__(
@@ -1635,11 +1590,7 @@ class ControlNetDataset(BaseDataset):
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
self.conditioning_image_transforms = transforms.Compose(
[
transforms.ToTensor(),
]
)
self.conditioning_image_transforms = IMAGE_TRANSFORMS
def make_buckets(self):
self.dreambooth_dataset_delegate.make_buckets()
@@ -1667,7 +1618,7 @@ class ControlNetDataset(BaseDataset):
original_size_hw = example["original_sizes_hw"][i]
crop_top_left = example["crop_top_lefts"][i]
flipped = example["flippeds"][i]
cond_img = self.load_image(image_info.cond_img_path)
cond_img = load_image(image_info.cond_img_path)
if self.dreambooth_dataset_delegate.enable_bucket:
cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
@@ -1729,6 +1680,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
print(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
dataset.set_caching_mode(caching_mode)
def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -1752,28 +1707,53 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.disable_token_padding()
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(npz_path) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]]]:
if npz_path is None: # flipped doesn't exist
return None, None, None
def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
if not os.path.exists(npz_path):
return False
npz = np.load(npz_path)
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
return False
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
return True
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(
npz_path,
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
npz = np.load(npz_path)
if "latents" not in npz:
print(f"error: npz is old format. please re-generate {npz_path}")
return None, None, None
return None, None, None, None
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_left_top = npz["crop_left_top"].tolist()
return latents, original_size, crop_left_top
flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
return latents, original_size, crop_left_top, flipped_latents
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top):
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top, flipped_latents_tensor=None):
kwargs = {}
if flipped_latents_tensor is not None:
kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
np.savez(
npz_path,
latents=latents_tensor.float().cpu().numpy(),
original_size=np.array(original_size),
crop_left_top=np.array(crop_left_top),
**kwargs,
)
@@ -1948,6 +1928,93 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
return train_dataset_group
def load_image(image_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
return img
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top)
def trim_and_resize_if_required(
random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
image_height, image_width = image.shape[0:2]
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height)
crop_left_top = (0, 0)
if image_width > reso[0]:
trim_size = image_width - reso[0]
p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
# print("w", trim_size, p)
image = image[:, p : p + reso[0]]
crop_left_top = (p, 0)
if image_height > reso[1]:
trim_size = image_height - reso[1]
p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
# print("h", trim_size, p)
image = image[p : p + reso[1]]
crop_left_top = (crop_left_top[0], p)
assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
return image, original_size, crop_left_top
def cache_batch_latents(
vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
) -> None:
r"""
requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
optionally requires image_infos to have: image
if cache_to_disk is True, set info.latents_npz
flipped latents is also saved if flip_aug is True
if cache_to_disk is False, set info.latents
latents_flipped is also set if flip_aug is True
latents_original_size and latents_crop_left_top are also set
"""
images = []
for info in image_infos:
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image = IMAGE_TRANSFORMS(image)
images.append(image)
info.latents_original_size = original_size
info.latents_crop_left_top = crop_left_top
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
if flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3])
with torch.no_grad():
flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
else:
flipped_latents = [None] * len(latents)
for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
# check NaN
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top, flipped_latent)
else:
info.latents = latent
if flip_aug:
info.latents_flipped = flipped_latent
# endregion
# region モジュール入れ替え部
@@ -3975,7 +4042,7 @@ def sample_images_common(
controlnet=controlnet,
controlnet_image=controlnet_image,
)
image = pipeline.latents_to_image(latents)[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())