refactor caching latents (flip in same npz, etc)

This commit is contained in:
Kohya S
2023-07-15 18:28:33 +09:00
parent 81fa54837f
commit 94c151aea3
3 changed files with 409 additions and 239 deletions

View File

@@ -34,22 +34,7 @@ def collate_fn_remove_corrupted(batch):
return batch return batch
def get_latents(vae, key_and_images, weight_dtype): def get_npz_filename(data_dir, image_key, is_full_path, recursive):
img_tensors = [IMAGE_TRANSFORMS(image) for _, image in key_and_images]
img_tensors = torch.stack(img_tensors)
img_tensors = img_tensors.to(DEVICE, weight_dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample()
# check NaN
for (key, _), latents1 in zip(key_and_images, latents):
if torch.isnan(latents1).any():
raise ValueError(f"NaN detected in latents of {key}")
return latents
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
if is_full_path: if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0] base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir) relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
@@ -57,13 +42,10 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
base_name = image_key base_name = image_key
relative_path = "" relative_path = ""
if flip:
base_name += "_flip"
if recursive and relative_path: if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name) return os.path.join(data_dir, relative_path, base_name) + ".npz"
else: else:
return os.path.join(data_dir, base_name) return os.path.join(data_dir, base_name) + ".npz"
def main(args): def main(args):
@@ -113,36 +95,7 @@ def main(args):
def process_batch(is_last): def process_batch(is_last):
for bucket in bucket_manager.buckets: for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [(key, img) for key, img, _, _ in bucket], weight_dtype) train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
assert (
latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
), f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
# flip
if args.flip_aug:
latents = get_latents(
vae, [(key, img[:, ::-1].copy()) for key, img, _, _ in bucket], weight_dtype
) # copyがないとTensor変換できない
for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(
args.train_data_dir, image_key, args.full_path, True, args.recursive
)
train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
else:
# remove existing flipped npz
for image_key, _ in bucket:
npz_file_name = (
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name)
bucket.clear() bucket.clear()
# 読み込みの高速化のためにDataLoaderを使うオプション # 読み込みの高速化のためにDataLoaderを使うオプション
@@ -203,61 +156,18 @@ def main(args):
), f"internal error resized size is small: {resized_size}, {reso}" ), f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshape等を確認して同じならskipする # 既に存在するファイルがあればshape等を確認して同じならskipする
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
if args.skip_existing: if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
if args.flip_aug:
npz_files.append(
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break
latents, _, _ = train_util.load_latents_from_disk(npz_file)
if latents is None: # old version
found = False
break
if latents.shape[1] != reso[1] // 8 or latents.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
continue continue
# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
trim_left = 0
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
trim_left = trim_size // 2
trim_top = 0
if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size // 2 : trim_size // 2 + reso[1]]
trim_top = trim_size // 2
original_size_wh = (resized_size[0], resized_size[1])
# target_size_wh = (reso[0], reso[1])
crop_left_top = (trim_left, trim_top)
assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
# # debug
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
# バッチへ追加 # バッチへ追加
bucket_manager.add_image(reso, (image_key, image, original_size_wh, crop_left_top)) image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
image_info.latents_npz = npz_file_name
image_info.bucket_reso = reso
image_info.resized_size = resized_size
image_info.image = image
bucket_manager.add_image(reso, image_info)
# バッチを推論するか判定して推論する # バッチを推論するか判定して推論する
process_batch(False) process_batch(False)

View File

@@ -50,6 +50,7 @@ from diffusers import (
HeunDiscreteScheduler, HeunDiscreteScheduler,
KDPM2DiscreteScheduler, KDPM2DiscreteScheduler,
KDPM2AncestralDiscreteScheduler, KDPM2AncestralDiscreteScheduler,
AutoencoderKL,
) )
from library import custom_train_functions from library import custom_train_functions
from library.original_unet import UNet2DConditionModel from library.original_unet import UNet2DConditionModel
@@ -96,6 +97,13 @@ try:
except: except:
pass pass
IMAGE_TRANSFORMS = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
class ImageInfo: class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None: def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -110,10 +118,10 @@ class ImageInfo:
self.latents: torch.Tensor = None self.latents: torch.Tensor = None
self.latents_flipped: torch.Tensor = None self.latents_flipped: torch.Tensor = None
self.latents_npz: str = None self.latents_npz: str = None
self.latents_npz_flipped: str = None
self.latents_original_size: Tuple[int, int] = None # original image size, not latents size self.latents_original_size: Tuple[int, int] = None # original image size, not latents size
self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top self.latents_crop_left_top: Tuple[int, int] = None # original image crop left top, not latents crop left top
self.cond_img_path: str = None self.cond_img_path: str = None
self.image: Optional[Image.Image] = None # optional, original PIL Image
class BucketManager: class BucketManager:
@@ -507,21 +515,22 @@ class BaseDataset(torch.utils.data.Dataset):
# augmentation # augmentation
self.aug_helper = AugHelper() self.aug_helper = AugHelper()
self.image_transforms = transforms.Compose( self.image_transforms = IMAGE_TRANSFORMS
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.image_data: Dict[str, ImageInfo] = {} self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {} self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
self.replacements = {} self.replacements = {}
# caching
self.caching_mode = None # None, 'latents', 'text'
def set_seed(self, seed): def set_seed(self, seed):
self.seed = seed self.seed = seed
def set_caching_mode(self, mode):
self.caching_mode = mode
def set_current_epoch(self, epoch): def set_current_epoch(self, epoch):
if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする
self.shuffle_buckets() self.shuffle_buckets()
@@ -767,45 +776,6 @@ class BaseDataset(torch.utils.data.Dataset):
random.shuffle(self.buckets_indices) random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle() self.bucket_manager.shuffle()
def load_image(self, image_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
img = np.array(image, np.uint8)
return img
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top)
def trim_and_resize_if_required(
self, subset: BaseSubset, image, reso, resized_size
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
image_height, image_width = image.shape[0:2]
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height)
crop_left_top = (0, 0)
if image_width > reso[0]:
trim_size = image_width - reso[0]
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
# print("w", trim_size, p)
image = image[:, p : p + reso[0]]
crop_left_top = (p, 0)
if image_height > reso[1]:
trim_size = image_height - reso[1]
p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
# print("h", trim_size, p)
image = image[p : p + reso[1]]
crop_left_top = (crop_left_top[0], p)
assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
return image, original_size, crop_left_top
def is_latent_cacheable(self): def is_latent_cacheable(self):
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets]) return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
@@ -822,26 +792,6 @@ class BaseDataset(torch.utils.data.Dataset):
] ]
) )
def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
for npath in [npz_path, flipped_npz_path]:
if npath is None:
continue
if not os.path.exists(npath):
return False
npz = np.load(npath)
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
return False
cached_latents = npz["latents"]
if cached_latents.shape[1:3] != expected_latents_size:
return False
return True
def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
# ちょっと速くした # ちょっと速くした
print("caching latents.") print("caching latents.")
@@ -864,13 +814,10 @@ class BaseDataset(torch.utils.data.Dataset):
# check disk cache exists and size of latents # check disk cache exists and size of latents
if cache_to_disk: if cache_to_disk:
info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz" info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz" if not is_main_process: # store to info only
if not is_main_process:
continue continue
cache_available = self.is_disk_cached_latents_is_expected( cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None
)
if cache_available: # do not add to batch if cache_available: # do not add to batch
continue continue
@@ -890,60 +837,19 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) > 0: if len(batch) > 0:
batches.append(batch) batches.append(batch)
if cache_to_disk and not is_main_process: # don't cache latents in non-main process, set to info only if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only
return return
# iterate batches # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
for batch in tqdm(batches, smoothing=1, total=len(batches)): for batch in tqdm(batches, smoothing=1, total=len(batches)):
images = [] cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
for info in batch:
image = self.load_image(info.absolute_path)
image, original_size, crop_left_top = self.trim_and_resize_if_required(
subset, image, info.bucket_reso, info.resized_size
)
image = self.image_transforms(image)
images.append(image)
info.latents_original_size = original_size
info.latents_crop_left_top = crop_left_top
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
# check NaN
if torch.isnan(latents).any():
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top)
else:
info.latents = latent
if subset.flip_aug:
img_tensors = torch.flip(img_tensors, dims=[3])
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
# check NaN
if torch.isnan(latents).any():
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk:
# crop_left_top is reversed when making batch
save_latents_to_disk(
info.latents_npz_flipped, latent, info.latents_original_size, info.latents_crop_left_top
)
else:
info.latents_flipped = latent
def get_image_size(self, image_path): def get_image_size(self, image_path):
image = Image.open(image_path) image = Image.open(image_path)
return image.size return image.size
def load_image_with_face_info(self, subset: BaseSubset, image_path: str): def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
img = self.load_image(image_path) img = load_image(image_path)
face_cx = face_cy = face_w = face_h = 0 face_cx = face_cy = face_w = face_h = 0
if subset.face_crop_aug_range is not None: if subset.face_crop_aug_range is not None:
@@ -1004,10 +910,6 @@ class BaseDataset(torch.utils.data.Dataset):
return image return image
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
return load_latents_from_disk(npz_file)
def __len__(self): def __len__(self):
return self._length return self._length
@@ -1016,6 +918,9 @@ class BaseDataset(torch.utils.data.Dataset):
bucket_batch_size = self.buckets_indices[index].bucket_batch_size bucket_batch_size = self.buckets_indices[index].bucket_batch_size
image_index = self.buckets_indices[index].batch_index * bucket_batch_size image_index = self.buckets_indices[index].batch_index * bucket_batch_size
if self.caching_mode is not None: # return batch for latents/text encoder outputs caching
return self.get_item_for_caching(bucket, bucket_batch_size, image_index)
loss_weights = [] loss_weights = []
captions = [] captions = []
input_ids_list = [] input_ids_list = []
@@ -1045,7 +950,10 @@ class BaseDataset(torch.utils.data.Dataset):
image = None image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_left_top = self.load_latents_from_npz(image_info, flipped) latents, original_size, crop_left_top, flipped_latents = load_latents_from_disk(image_info.latents_npz)
if flipped:
latents = flipped_latents
del flipped_latents
latents = torch.FloatTensor(latents) latents = torch.FloatTensor(latents)
image = None image = None
@@ -1055,8 +963,8 @@ class BaseDataset(torch.utils.data.Dataset):
im_h, im_w = img.shape[0:2] im_h, im_w = img.shape[0:2]
if self.enable_bucket: if self.enable_bucket:
img, original_size, crop_left_top = self.trim_and_resize_if_required( img, original_size, crop_left_top = trim_and_resize_if_required(
subset, img, image_info.bucket_reso, image_info.resized_size subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
) )
else: else:
if face_cx > 0: # 顔位置情報あり if face_cx > 0: # 顔位置情報あり
@@ -1162,6 +1070,53 @@ class BaseDataset(torch.utils.data.Dataset):
example["image_keys"] = bucket[image_index : image_index + self.batch_size] example["image_keys"] = bucket[image_index : image_index + self.batch_size]
return example return example
def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
    """Build one batch for latents/text-encoder-output caching.

    Taken instead of a normal training example when ``self.caching_mode`` is set
    (see ``__getitem__``). Returns a dict with per-image captions, absolute paths,
    resized sizes and (for latents caching only) loaded images, plus the batch-wide
    ``flip_aug`` / ``random_crop`` / ``bucket_reso`` settings, which must be
    identical for every image in the batch.
    """
    captions = []
    images = []
    absolute_paths = []
    resized_sizes = []
    bucket_reso = None
    flip_aug = None
    random_crop = None

    for image_key in bucket[image_index : image_index + bucket_batch_size]:
        image_info = self.image_data[image_key]
        subset = self.image_to_subset[image_key]

        if flip_aug is None:
            # first item in the batch fixes the batch-wide settings
            flip_aug = subset.flip_aug
            random_crop = subset.random_crop
            bucket_reso = image_info.bucket_reso
        else:
            # all remaining items must agree, since caching is done per batch
            assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
            assert random_crop == subset.random_crop, "random_crop must be same in a batch"
            assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"

        caption = image_info.caption  # TODO cache some patterns of droping, shuffling, etc.

        if self.caching_mode == "latents":
            # latents caching needs the raw pixels; text caching does not
            image = load_image(image_info.absolute_path)
        else:
            image = None

        captions.append(caption)
        images.append(image)
        absolute_paths.append(image_info.absolute_path)
        resized_sizes.append(image_info.resized_size)

    example = {}

    # collapse to None when no image was loaded (text caching mode)
    if images[0] is None:
        images = None
    example["images"] = images

    example["captions"] = captions
    example["absolute_paths"] = absolute_paths
    example["resized_sizes"] = resized_sizes
    example["flip_aug"] = flip_aug
    example["random_crop"] = random_crop
    example["bucket_reso"] = bucket_reso
    return example
class DreamBoothDataset(BaseDataset): class DreamBoothDataset(BaseDataset):
def __init__( def __init__(
@@ -1635,11 +1590,7 @@ class ControlNetDataset(BaseDataset):
assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}" assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}" assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"
self.conditioning_image_transforms = transforms.Compose( self.conditioning_image_transforms = IMAGE_TRANSFORMS
[
transforms.ToTensor(),
]
)
def make_buckets(self): def make_buckets(self):
self.dreambooth_dataset_delegate.make_buckets() self.dreambooth_dataset_delegate.make_buckets()
@@ -1667,7 +1618,7 @@ class ControlNetDataset(BaseDataset):
original_size_hw = example["original_sizes_hw"][i] original_size_hw = example["original_sizes_hw"][i]
crop_top_left = example["crop_top_lefts"][i] crop_top_left = example["crop_top_lefts"][i]
flipped = example["flippeds"][i] flipped = example["flippeds"][i]
cond_img = self.load_image(image_info.cond_img_path) cond_img = load_image(image_info.cond_img_path)
if self.dreambooth_dataset_delegate.enable_bucket: if self.dreambooth_dataset_delegate.enable_bucket:
cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
@@ -1729,6 +1680,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
print(f"[Dataset {i}]") print(f"[Dataset {i}]")
dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
def set_caching_mode(self, caching_mode):
for dataset in self.datasets:
dataset.set_caching_mode(caching_mode)
def is_latent_cacheable(self) -> bool: def is_latent_cacheable(self) -> bool:
return all([dataset.is_latent_cacheable() for dataset in self.datasets]) return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -1752,28 +1707,53 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.disable_token_padding() dataset.disable_token_padding()
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
def load_latents_from_disk(npz_path) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]]]: expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
if npz_path is None: # flipped doesn't exist
return None, None, None
if not os.path.exists(npz_path):
return False
npz = np.load(npz_path)
if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz: # old ver?
return False
if npz["latents"].shape[1:3] != expected_latents_size:
return False
if flip_aug:
if "latents_flipped" not in npz:
return False
if npz["latents_flipped"].shape[1:3] != expected_latents_size:
return False
return True
# return values: latents_tensor, (original_size width, original_size height), (crop left, crop top), flipped latents (or None)
def load_latents_from_disk(
    npz_path,
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
    """Load cached latents from *npz_path*.

    Returns (latents, original_size, crop_left_top, flipped_latents); all four are
    None when the file is in the old format (no "latents" key). flipped_latents is
    None when the cache was written without flip augmentation.
    """
    npz = np.load(npz_path)
    if "latents" not in npz:
        print(f"error: npz is old format. please re-generate {npz_path}")
        return None, None, None, None

    latents = npz["latents"]
    original_size = npz["original_size"].tolist()
    crop_left_top = npz["crop_left_top"].tolist()
    # flipped latents live in the same npz since commit "flip in same npz"
    flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
    return latents, original_size, crop_left_top, flipped_latents
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top, flipped_latents_tensor=None):
    """Save latents (and optionally the flipped latents) to one .npz file.

    Tensors are stored as float32 numpy arrays. The flipped latents, when given,
    go into the same file under the "latents_flipped" key so that one npz holds
    everything for an image.
    """
    kwargs = {}
    if flipped_latents_tensor is not None:
        kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
    np.savez(
        npz_path,
        latents=latents_tensor.float().cpu().numpy(),
        original_size=np.array(original_size),
        crop_left_top=np.array(crop_left_top),
        **kwargs,
    )
@@ -1948,6 +1928,93 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
return train_dataset_group return train_dataset_group
def load_image(image_path):
    """Load *image_path* with PIL and return it as an RGB uint8 numpy array."""
    pil_image = Image.open(image_path)
    if pil_image.mode != "RGB":
        # normalize palette / grayscale / RGBA inputs to 3-channel RGB
        pil_image = pil_image.convert("RGB")
    return np.array(pil_image, np.uint8)
# return values: numpy.ndarray, (original width, original height), (crop left, crop top)
def trim_and_resize_if_required(
    random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
    """Resize *image* to *resized_size* (if needed) and crop it to the bucket reso.

    Fix: the parameter annotation previously claimed ``Image.Image`` but the
    function operates on a numpy HxWxC array (it reads ``image.shape`` and calls
    ``cv2.resize``) — the annotation is corrected to ``np.ndarray``.

    Args:
        random_crop: when True, the crop offset is random; otherwise it is centered.
        image: HxWxC uint8 numpy array.
        reso: target bucket resolution as (width, height).
        resized_size: intermediate resize target as (width, height).

    Returns:
        (cropped image, original (width, height) after resize, (crop left, crop top)).
    """
    image_height, image_width = image.shape[0:2]

    if image_width != resized_size[0] or image_height != resized_size[1]:
        # resize with cv2 because PIL has no INTER_AREA equivalent
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
        image_height, image_width = image.shape[0:2]

    original_size = (image_width, image_height)

    crop_left_top = (0, 0)
    if image_width > reso[0]:
        trim_size = image_width - reso[0]
        p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
        image = image[:, p : p + reso[0]]
        crop_left_top = (p, 0)
    if image_height > reso[1]:
        trim_size = image_height - reso[1]
        p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
        image = image[p : p + reso[1]]
        crop_left_top = (crop_left_top[0], p)

    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
    return image, original_size, crop_left_top
def cache_batch_latents(
    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
) -> None:
    r"""Encode a batch of images with the VAE and cache the latents.

    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
    optionally requires image_infos to have: image (pre-loaded pixels)
    if cache_to_disk is True, latents are saved to info.latents_npz
        flipped latents are also saved if flip_aug is True
    if cache_to_disk is False, set info.latents
        latents_flipped is also set if flip_aug is True
    latents_original_size and latents_crop_left_top are also set
    """
    images = []
    for info in image_infos:
        # use the pre-loaded image when the caching dataloader supplied one
        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
        image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
        image = IMAGE_TRANSFORMS(image)
        images.append(image)

        info.latents_original_size = original_size
        info.latents_crop_left_top = crop_left_top

    img_tensors = torch.stack(images, dim=0)
    img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)

    with torch.no_grad():
        latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")

    if flip_aug:
        img_tensors = torch.flip(img_tensors, dims=[3])  # flip horizontally (W axis)
        with torch.no_grad():
            flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
    else:
        flipped_latents = [None] * len(latents)

    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
        # check NaN per item (fix: previously tested the whole batch tensor inside
        # this loop, which blamed the first image for a NaN anywhere in the batch
        # and rescanned the batch once per item)
        if torch.isnan(latent).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")

        if cache_to_disk:
            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top, flipped_latent)
        else:
            info.latents = latent
            if flip_aug:
                info.latents_flipped = flipped_latent
# endregion # endregion
# region モジュール入れ替え部 # region モジュール入れ替え部
@@ -3975,7 +4042,7 @@ def sample_images_common(
controlnet=controlnet, controlnet=controlnet,
controlnet_image=controlnet_image, controlnet_image=controlnet_image,
) )
image = pipeline.latents_to_image(latents)[0] image = pipeline.latents_to_image(latents)[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())

193
tools/cache_latents.py Normal file
View File

@@ -0,0 +1,193 @@
# latentsのdiskへの事前キャッシュを行う / cache latents to disk
import argparse
import math
from multiprocessing import Value
import os
from accelerate.utils import set_seed
import torch
from tqdm import tqdm
from library import config_util
from library import train_util
from library import sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
def cache_to_disk(args: argparse.Namespace) -> None:
    """Pre-compute VAE latents for every dataset image and cache them to .npz files.

    Builds the dataset from args/config, switches it into "latents" caching mode so
    the dataloader yields caching batches (raw images + paths) instead of training
    examples, then encodes each sub-batch with the VAE via
    train_util.cache_batch_latents. Requires --cache_latents_to_disk.
    """
    train_util.prepare_dataset_args(args, True)

    # caching to disk is the whole point of this tool
    assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"

    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random sequences

    # prepare tokenizer(s): needed to instantiate the dataset
    if args.sdxl:
        tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
        tokenizers = [tokenizer1, tokenizer2]
    else:
        tokenizer = train_util.load_tokenizer(args)
        tokenizers = [tokenizer]

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                print(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                print("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                print("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    # without calling the dataset's cache_latents, raw images are returned by __getitem__
    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

    # prepare accelerator (for multi-GPU caching)
    print("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # prepare dtypes for mixed precision; cast as appropriate
    weight_dtype, _ = train_util.prepare_dtype(args)
    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

    # load the model (only the VAE is used)
    print("load model")
    if args.sdxl:
        (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
    else:
        _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)

    vae.set_use_memory_efficient_attention_xformers(args.xformers)
    vae.to(accelerator.device, dtype=vae_dtype)
    vae.requires_grad_(False)
    vae.eval()

    # prepare the dataloader; caching mode makes the dataset return caching batches
    train_dataset_group.set_caching_mode("latents")

    # num_workers == 0 means the main process does the loading
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count-1, capped at the requested count
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collater,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # prepare with accelerator so multi-GPU sharding works
    train_dataloader = accelerator.prepare(train_dataloader)

    # iterate the data and cache latents
    for batch in tqdm(train_dataloader):
        b_size = len(batch["images"])
        vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
        flip_aug = batch["flip_aug"]
        random_crop = batch["random_crop"]
        bucket_reso = batch["bucket_reso"]

        # split the batch into VAE-sized sub-batches
        for i in range(0, b_size, vae_batch_size):
            images = batch["images"][i : i + vae_batch_size]
            absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
            resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]

            image_infos = []
            # fix: the original used `for i, (...) in enumerate(...)` here, shadowing
            # the outer sub-batch index `i` with an unused counter
            for image, absolute_path, resized_size in zip(images, absolute_paths, resized_sizes):
                image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
                image_info.image = image
                image_info.bucket_reso = bucket_reso
                image_info.resized_size = resized_size
                image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"

                if args.skip_existing:
                    if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
                        print(f"Skipping {image_info.latents_npz} because it already exists.")
                        continue

                image_infos.append(image_info)

            if len(image_infos) > 0:
                train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)

    accelerator.wait_for_everyone()
    accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")
def setup_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the latents-caching tool.

    Reuses the shared model / training / dataset / config argument groups from
    train_util and config_util, then adds the tool-specific flags.
    """
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_training_arguments(parser, True)
    train_util.add_dataset_arguments(parser, True, True, True)
    config_util.add_config_arguments(parser)
    parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
    parser.add_argument(
        "--no_half_vae",
        action="store_true",
        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップするflip_aug有効時は通常、反転の両方が存在する画像をスキップ",
    )
    return parser
if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    # merge any --config_file settings into the parsed args
    args = train_util.read_config_from_file(args, parser)

    cache_to_disk(args)