Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-10 15:00:23 +00:00)

Commit: refactor caching latents (flip in same npz, etc)
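The change in a nutshell: instead of writing a separate "<name>_flip.npz" next to each image, the flipped latents now go into the same .npz under a "latents_flipped" key, alongside "latents", "original_size", and "crop_left_top". A quick way to inspect one of the new cache files (a sketch; "img001.npz" is a placeholder path):

# a sketch for inspecting one combined cache file; "img001.npz" is a placeholder
import numpy as np

npz = np.load("img001.npz")
print(sorted(npz.keys()))                # latents, original_size, crop_left_top,
                                         # plus latents_flipped when flip_aug was enabled
print(npz["latents"].shape)              # (4, H // 8, W // 8)
if "latents_flipped" in npz:
    print(npz["latents_flipped"].shape)  # same shape as "latents"
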
@@ -34,22 +34,7 @@ def collate_fn_remove_corrupted(batch):
     return batch


-def get_latents(vae, key_and_images, weight_dtype):
-    img_tensors = [IMAGE_TRANSFORMS(image) for _, image in key_and_images]
-    img_tensors = torch.stack(img_tensors)
-    img_tensors = img_tensors.to(DEVICE, weight_dtype)
-    with torch.no_grad():
-        latents = vae.encode(img_tensors).latent_dist.sample()
-
-    # check NaN
-    for (key, _), latents1 in zip(key_and_images, latents):
-        if torch.isnan(latents1).any():
-            raise ValueError(f"NaN detected in latents of {key}")
-
-    return latents
-
-
-def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
+def get_npz_filename(data_dir, image_key, is_full_path, recursive):
     if is_full_path:
         base_name = os.path.splitext(os.path.basename(image_key))[0]
         relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
@@ -57,13 +42,10 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
         base_name = image_key
         relative_path = ""

-    if flip:
-        base_name += "_flip"
-
     if recursive and relative_path:
-        return os.path.join(data_dir, relative_path, base_name)
+        return os.path.join(data_dir, relative_path, base_name) + ".npz"
     else:
-        return os.path.join(data_dir, base_name)
+        return os.path.join(data_dir, base_name) + ".npz"


 def main(args):
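With the flip variant folded into the same cache file, the helper above loses its flip parameter and now returns the full .npz path. A self-contained paraphrase of the refactored helper, just to illustrate the resulting paths (hypothetical directory and file names; the function name npz_path is not from the source):

import os

# paraphrase of the refactored get_npz_filename above, for illustration only
def npz_path(data_dir, image_key, is_full_path, recursive):
    if is_full_path:
        base_name = os.path.splitext(os.path.basename(image_key))[0]
        relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
    else:
        base_name = image_key
        relative_path = ""
    if recursive and relative_path:
        return os.path.join(data_dir, relative_path, base_name) + ".npz"
    return os.path.join(data_dir, base_name) + ".npz"

print(npz_path("train", "train/sub/img001.png", True, True))   # train/sub/img001.npz
print(npz_path("train", "train/sub/img001.png", True, False))  # train/img001.npz
print(npz_path("train", "img001", False, False))                # train/img001.npz
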
@@ -113,36 +95,7 @@ def main(args):
     def process_batch(is_last):
         for bucket in bucket_manager.buckets:
             if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
-                latents = get_latents(vae, [(key, img) for key, img, _, _ in bucket], weight_dtype)
-                assert (
-                    latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
-                ), f"latent shape {latents.shape}, {bucket[0][1].shape}"
-
-                for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
-                    npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
-                    train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
-
-                # flip
-                if args.flip_aug:
-                    latents = get_latents(
-                        vae, [(key, img[:, ::-1].copy()) for key, img, _, _ in bucket], weight_dtype
-                    )  # without the copy, the array cannot be converted to a Tensor
-
-                    for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
-                        npz_file_name = get_npz_filename_wo_ext(
-                            args.train_data_dir, image_key, args.full_path, True, args.recursive
-                        )
-                        train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
-                else:
-                    # remove existing flipped npz
-                    for image_key, _ in bucket:
-                        npz_file_name = (
-                            get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
-                        )
-                        if os.path.isfile(npz_file_name):
-                            print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
-                            os.remove(npz_file_name)
-
+                train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
                 bucket.clear()

     # option to use a DataLoader to speed up loading
@@ -203,61 +156,18 @@ def main(args):
         ), f"internal error resized size is small: {resized_size}, {reso}"

         # if a cached file already exists, check its shape etc. and skip if it matches
+        npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
         if args.skip_existing:
-            npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
-            if args.flip_aug:
-                npz_files.append(
-                    get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
-                )
-
-            found = True
-            for npz_file in npz_files:
-                if not os.path.exists(npz_file):
-                    found = False
-                    break
-
-                latents, _, _ = train_util.load_latents_from_disk(npz_file)
-                if latents is None:  # old version
-                    found = False
-                    break
-
-                if latents.shape[1] != reso[1] // 8 or latents.shape[2] != reso[0] // 8:  # check latents shape
-                    found = False
-                    break
-            if found:
+            if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
                 continue

-        # resize and trim the image
-        # PIL has no INTER_AREA, so use cv2...
-        image = np.array(image)
-        if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]:  # resize needed?
-            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
-
-        trim_left = 0
-        if resized_size[0] > reso[0]:
-            trim_size = resized_size[0] - reso[0]
-            image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
-            trim_left = trim_size // 2
-
-        trim_top = 0
-        if resized_size[1] > reso[1]:
-            trim_size = resized_size[1] - reso[1]
-            image = image[trim_size // 2 : trim_size // 2 + reso[1]]
-            trim_top = trim_size // 2
-
-        original_size_wh = (resized_size[0], resized_size[1])
-        # target_size_wh = (reso[0], reso[1])
-        crop_left_top = (trim_left, trim_top)
-
-        assert (
-            image.shape[0] == reso[1] and image.shape[1] == reso[0]
-        ), f"internal error, illegal trimmed size: {image.shape}, {reso}"
-
-        # # debug
-        # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
-
         # add to batch
-        bucket_manager.add_image(reso, (image_key, image, original_size_wh, crop_left_top))
+        image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
+        image_info.latents_npz = npz_file_name
+        image_info.bucket_reso = reso
+        image_info.resized_size = resized_size
+        image_info.image = image
+        bucket_manager.add_image(reso, image_info)

         # determine whether to run inference on the batch, and run it
         process_batch(False)
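After this hunk, the script's per-bucket work is delegated to train_util.cache_batch_latents, so each image only needs a populated ImageInfo. A minimal sketch of that call, assuming vae is an already-loaded AutoencoderKL and the paths are placeholders (the required fields follow the cache_batch_latents docstring further down in this diff):

# a minimal sketch; assumes `vae` is an already-loaded AutoencoderKL on the target device
# and "img001.png" / "img001.npz" are placeholder paths
from library import train_util

info = train_util.ImageInfo("img001", 1, "", False, "img001.png")
info.bucket_reso = (512, 512)    # W x H of the bucket the image falls into
info.resized_size = (512, 512)   # size the image is resized to before cropping
info.latents_npz = "img001.npz"  # one combined cache file per image

# cache_to_disk=True writes latents (and latents_flipped, because flip_aug=True) to info.latents_npz
train_util.cache_batch_latents(vae, True, [info], flip_aug=True, random_crop=False)
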
@@ -50,6 +50,7 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
+    AutoencoderKL,
 )
 from library import custom_train_functions
 from library.original_unet import UNet2DConditionModel
@@ -96,6 +97,13 @@ try:
 except:
     pass

+IMAGE_TRANSFORMS = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize([0.5], [0.5]),
+    ]
+)
+

 class ImageInfo:
     def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
@@ -110,10 +118,10 @@ class ImageInfo:
         self.latents: torch.Tensor = None
         self.latents_flipped: torch.Tensor = None
         self.latents_npz: str = None
-        self.latents_npz_flipped: str = None
         self.latents_original_size: Tuple[int, int] = None  # original image size, not latents size
         self.latents_crop_left_top: Tuple[int, int] = None  # original image crop left top, not latents crop left top
         self.cond_img_path: str = None
+        self.image: Optional[Image.Image] = None  # optional, original PIL Image


 class BucketManager:
@@ -507,21 +515,22 @@ class BaseDataset(torch.utils.data.Dataset):
         # augmentation
         self.aug_helper = AugHelper()

-        self.image_transforms = transforms.Compose(
-            [
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        self.image_transforms = IMAGE_TRANSFORMS

         self.image_data: Dict[str, ImageInfo] = {}
         self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}

         self.replacements = {}

+        # caching
+        self.caching_mode = None  # None, 'latents', 'text'
+
     def set_seed(self, seed):
         self.seed = seed

+    def set_caching_mode(self, mode):
+        self.caching_mode = mode
+
     def set_current_epoch(self, epoch):
         if not self.current_epoch == epoch:  # shuffle buckets when the epoch changes
             self.shuffle_buckets()
@@ -767,45 +776,6 @@ class BaseDataset(torch.utils.data.Dataset):
         random.shuffle(self.buckets_indices)
         self.bucket_manager.shuffle()

-    def load_image(self, image_path):
-        image = Image.open(image_path)
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
-        img = np.array(image, np.uint8)
-        return img
-
-    # load an image; returns numpy.ndarray, (original width, original height), (crop left, crop top)
-    def trim_and_resize_if_required(
-        self, subset: BaseSubset, image, reso, resized_size
-    ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
-        image_height, image_width = image.shape[0:2]
-
-        if image_width != resized_size[0] or image_height != resized_size[1]:
-            # resize
-            image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # resize with cv2 because we want INTER_AREA
-
-        image_height, image_width = image.shape[0:2]
-        original_size = (image_width, image_height)
-
-        crop_left_top = (0, 0)
-        if image_width > reso[0]:
-            trim_size = image_width - reso[0]
-            p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
-            # print("w", trim_size, p)
-            image = image[:, p : p + reso[0]]
-            crop_left_top = (p, 0)
-        if image_height > reso[1]:
-            trim_size = image_height - reso[1]
-            p = trim_size // 2 if not subset.random_crop else random.randint(0, trim_size)
-            # print("h", trim_size, p)
-            image = image[p : p + reso[1]]
-            crop_left_top = (crop_left_top[0], p)
-
-        assert (
-            image.shape[0] == reso[1] and image.shape[1] == reso[0]
-        ), f"internal error, illegal trimmed size: {image.shape}, {reso}"
-        return image, original_size, crop_left_top
-
     def is_latent_cacheable(self):
         return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])

@@ -822,26 +792,6 @@ class BaseDataset(torch.utils.data.Dataset):
             ]
         )

-    def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
-        expected_latents_size = (reso[1] // 8, reso[0] // 8)  # note: bucket_reso is WxH
-
-        for npath in [npz_path, flipped_npz_path]:
-            if npath is None:
-                continue
-            if not os.path.exists(npath):
-                return False
-
-            npz = np.load(npath)
-            if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz:  # old ver?
-                return False
-
-            cached_latents = npz["latents"]
-
-            if cached_latents.shape[1:3] != expected_latents_size:
-                return False
-
-        return True
-
     def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         # made this a bit faster
         print("caching latents.")
@@ -864,13 +814,10 @@ class BaseDataset(torch.utils.data.Dataset):
             # check disk cache exists and size of latents
             if cache_to_disk:
                 info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
-                info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
-                if not is_main_process:
+                if not is_main_process:  # store to info only
                     continue

-                cache_available = self.is_disk_cached_latents_is_expected(
-                    info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None
-                )
+                cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug)

                 if cache_available:  # do not add to batch
                     continue
@@ -890,60 +837,19 @@ class BaseDataset(torch.utils.data.Dataset):
         if len(batch) > 0:
             batches.append(batch)

-        if cache_to_disk and not is_main_process:  # don't cache latents in non-main process, set to info only
+        if cache_to_disk and not is_main_process:  # if cache to disk, don't cache latents in non-main process, set to info only
             return

-        # iterate batches
+        # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
-            images = []
-            for info in batch:
-                image = self.load_image(info.absolute_path)
-                image, original_size, crop_left_top = self.trim_and_resize_if_required(
-                    subset, image, info.bucket_reso, info.resized_size
-                )
-                image = self.image_transforms(image)
-                images.append(image)
-
-                info.latents_original_size = original_size
-                info.latents_crop_left_top = crop_left_top
-
-            img_tensors = torch.stack(images, dim=0)
-            img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
-
-            latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
-
-            for info, latent in zip(batch, latents):
-                # check NaN
-                if torch.isnan(latents).any():
-                    raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
-
-                if cache_to_disk:
-                    save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top)
-                else:
-                    info.latents = latent
-
-            if subset.flip_aug:
-                img_tensors = torch.flip(img_tensors, dims=[3])
-                latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
-                for info, latent in zip(batch, latents):
-                    # check NaN
-                    if torch.isnan(latents).any():
-                        raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
-
-                    if cache_to_disk:
-                        # crop_left_top is reversed when making batch
-                        save_latents_to_disk(
-                            info.latents_npz_flipped, latent, info.latents_original_size, info.latents_crop_left_top
-                        )
-                    else:
-                        info.latents_flipped = latent
+            cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)

     def get_image_size(self, image_path):
         image = Image.open(image_path)
         return image.size

     def load_image_with_face_info(self, subset: BaseSubset, image_path: str):
-        img = self.load_image(image_path)
+        img = load_image(image_path)

         face_cx = face_cy = face_w = face_h = 0
         if subset.face_crop_aug_range is not None:
@@ -1004,10 +910,6 @@ class BaseDataset(torch.utils.data.Dataset):

         return image

-    def load_latents_from_npz(self, image_info: ImageInfo, flipped):
-        npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
-        return load_latents_from_disk(npz_file)
-
     def __len__(self):
         return self._length

@@ -1016,6 +918,9 @@ class BaseDataset(torch.utils.data.Dataset):
         bucket_batch_size = self.buckets_indices[index].bucket_batch_size
         image_index = self.buckets_indices[index].batch_index * bucket_batch_size

+        if self.caching_mode is not None:  # return batch for latents/text encoder outputs caching
+            return self.get_item_for_caching(bucket, bucket_batch_size, image_index)
+
         loss_weights = []
         captions = []
         input_ids_list = []
@@ -1045,7 +950,10 @@ class BaseDataset(torch.utils.data.Dataset):

                 image = None
             elif image_info.latents_npz is not None:  # FineTuningDataset, or when cache_latents_to_disk=True
-                latents, original_size, crop_left_top = self.load_latents_from_npz(image_info, flipped)
+                latents, original_size, crop_left_top, flipped_latents = load_latents_from_disk(image_info.latents_npz)
+                if flipped:
+                    latents = flipped_latents
+                    del flipped_latents
                 latents = torch.FloatTensor(latents)

                 image = None
@@ -1055,8 +963,8 @@ class BaseDataset(torch.utils.data.Dataset):
                 im_h, im_w = img.shape[0:2]

                 if self.enable_bucket:
-                    img, original_size, crop_left_top = self.trim_and_resize_if_required(
-                        subset, img, image_info.bucket_reso, image_info.resized_size
+                    img, original_size, crop_left_top = trim_and_resize_if_required(
+                        subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
                     )
                 else:
                     if face_cx > 0:  # face position info is available
@@ -1162,6 +1070,53 @@ class BaseDataset(torch.utils.data.Dataset):
         example["image_keys"] = bucket[image_index : image_index + self.batch_size]
         return example

+    def get_item_for_caching(self, bucket, bucket_batch_size, image_index):
+        captions = []
+        images = []
+        absolute_paths = []
+        resized_sizes = []
+        bucket_reso = None
+        flip_aug = None
+        random_crop = None
+
+        for image_key in bucket[image_index : image_index + bucket_batch_size]:
+            image_info = self.image_data[image_key]
+            subset = self.image_to_subset[image_key]
+
+            if flip_aug is None:
+                flip_aug = subset.flip_aug
+                random_crop = subset.random_crop
+                bucket_reso = image_info.bucket_reso
+            else:
+                assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch"
+                assert random_crop == subset.random_crop, "random_crop must be same in a batch"
+                assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch"
+
+            caption = image_info.caption  # TODO cache some patterns of dropping, shuffling, etc.
+            if self.caching_mode == "latents":
+                image = load_image(image_info.absolute_path)
+            else:
+                image = None
+
+            captions.append(caption)
+            images.append(image)
+            absolute_paths.append(image_info.absolute_path)
+            resized_sizes.append(image_info.resized_size)
+
+        example = {}
+
+        if images[0] is None:
+            images = None
+        example["images"] = images
+
+        example["captions"] = captions
+        example["absolute_paths"] = absolute_paths
+        example["resized_sizes"] = resized_sizes
+        example["flip_aug"] = flip_aug
+        example["random_crop"] = random_crop
+        example["bucket_reso"] = bucket_reso
+        return example
+

 class DreamBoothDataset(BaseDataset):
     def __init__(
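The new caching_mode path turns __getitem__ into a producer of caching batches. Roughly how a caller consumes it (a sketch; assumes `dataset` is an already-built BaseDataset subclass with its buckets prepared):

# a sketch; assumes `dataset` is a prepared BaseDataset subclass (buckets already made)
dataset.set_caching_mode("latents")

batch = dataset[0]  # served by get_item_for_caching instead of the training path
# keys: "images", "captions", "absolute_paths", "resized_sizes",
#       "flip_aug", "random_crop", "bucket_reso"
for path, size in zip(batch["absolute_paths"], batch["resized_sizes"]):
    print(path, size)
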
@@ -1635,11 +1590,7 @@ class ControlNetDataset(BaseDataset):
         assert len(missing_imgs) == 0, f"missing conditioning data for {len(missing_imgs)} images: {missing_imgs}"
         assert len(extra_imgs) == 0, f"extra conditioning data for {len(extra_imgs)} images: {extra_imgs}"

-        self.conditioning_image_transforms = transforms.Compose(
-            [
-                transforms.ToTensor(),
-            ]
-        )
+        self.conditioning_image_transforms = IMAGE_TRANSFORMS

     def make_buckets(self):
         self.dreambooth_dataset_delegate.make_buckets()
@@ -1667,7 +1618,7 @@ class ControlNetDataset(BaseDataset):
             original_size_hw = example["original_sizes_hw"][i]
             crop_top_left = example["crop_top_lefts"][i]
             flipped = example["flippeds"][i]
-            cond_img = self.load_image(image_info.cond_img_path)
+            cond_img = load_image(image_info.cond_img_path)

             if self.dreambooth_dataset_delegate.enable_bucket:
                 cond_img = cv2.resize(cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA)  # resize with cv2 because we want INTER_AREA
@@ -1729,6 +1680,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             print(f"[Dataset {i}]")
             dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)

+    def set_caching_mode(self, caching_mode):
+        for dataset in self.datasets:
+            dataset.set_caching_mode(caching_mode)
+
     def is_latent_cacheable(self) -> bool:
         return all([dataset.is_latent_cacheable() for dataset in self.datasets])

@@ -1752,28 +1707,53 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
             dataset.disable_token_padding()


-# returns latents_tensor, (original_size width, original_size height), (crop left, crop top)
-def load_latents_from_disk(npz_path) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]]]:
-    if npz_path is None:  # flipped doesn't exist
-        return None, None, None
+def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
+    expected_latents_size = (reso[1] // 8, reso[0] // 8)  # note: bucket_reso is WxH

+    if not os.path.exists(npz_path):
+        return False
+
+    npz = np.load(npz_path)
+    if "latents" not in npz or "original_size" not in npz or "crop_left_top" not in npz:  # old ver?
+        return False
+    if npz["latents"].shape[1:3] != expected_latents_size:
+        return False
+
+    if flip_aug:
+        if "latents_flipped" not in npz:
+            return False
+        if npz["latents_flipped"].shape[1:3] != expected_latents_size:
+            return False
+
+    return True
+
+
+# returns latents_tensor, (original_size width, original_size height), (crop left, crop top)
+def load_latents_from_disk(
+    npz_path,
+) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
     npz = np.load(npz_path)
     if "latents" not in npz:
         print(f"error: npz is old format. please re-generate {npz_path}")
-        return None, None, None
+        return None, None, None, None

     latents = npz["latents"]
     original_size = npz["original_size"].tolist()
     crop_left_top = npz["crop_left_top"].tolist()
-    return latents, original_size, crop_left_top
+    flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
+    return latents, original_size, crop_left_top, flipped_latents


-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top):
+def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top, flipped_latents_tensor=None):
+    kwargs = {}
+    if flipped_latents_tensor is not None:
+        kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
     np.savez(
         npz_path,
         latents=latents_tensor.float().cpu().numpy(),
         original_size=np.array(original_size),
         crop_left_top=np.array(crop_left_top),
+        **kwargs,
     )

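The save/load pair now round-trips the flipped latents through the same file, and load_latents_from_disk returns a fourth element. A small round-trip sketch (placeholder file name; note that real flipped latents come from re-encoding the flipped image, not from flipping the latent tensor):

# a round-trip sketch; "img001.npz" is a placeholder file name
import torch
from library import train_util

latents = torch.randn(4, 64, 64)         # latents of a 512x512 image: (4, H // 8, W // 8)
flipped = torch.flip(latents, dims=[2])  # stand-in only; the real code re-encodes the flipped image

train_util.save_latents_to_disk("img001.npz", latents, (512, 512), (0, 0), flipped)

lat, original_size, crop_left_top, lat_flipped = train_util.load_latents_from_disk("img001.npz")
assert lat_flipped is not None  # present because a flipped tensor was passed to save_latents_to_disk
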
@@ -1948,6 +1928,93 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
     return train_dataset_group


+def load_image(image_path):
+    image = Image.open(image_path)
+    if not image.mode == "RGB":
+        image = image.convert("RGB")
+    img = np.array(image, np.uint8)
+    return img
+
+
+# load an image; returns numpy.ndarray, (original width, original height), (crop left, crop top)
+def trim_and_resize_if_required(
+    random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
+) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
+    image_height, image_width = image.shape[0:2]
+
+    if image_width != resized_size[0] or image_height != resized_size[1]:
+        # resize
+        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # resize with cv2 because we want INTER_AREA
+
+    image_height, image_width = image.shape[0:2]
+    original_size = (image_width, image_height)
+
+    crop_left_top = (0, 0)
+    if image_width > reso[0]:
+        trim_size = image_width - reso[0]
+        p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
+        # print("w", trim_size, p)
+        image = image[:, p : p + reso[0]]
+        crop_left_top = (p, 0)
+    if image_height > reso[1]:
+        trim_size = image_height - reso[1]
+        p = trim_size // 2 if not random_crop else random.randint(0, trim_size)
+        # print("h", trim_size, p)
+        image = image[p : p + reso[1]]
+        crop_left_top = (crop_left_top[0], p)
+
+    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
+    return image, original_size, crop_left_top
+
+
+def cache_batch_latents(
+    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
+) -> None:
+    r"""
+    requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
+    optionally requires image_infos to have: image
+    if cache_to_disk is True, set info.latents_npz
+    flipped latents is also saved if flip_aug is True
+    if cache_to_disk is False, set info.latents
+    latents_flipped is also set if flip_aug is True
+    latents_original_size and latents_crop_left_top are also set
+    """
+    images = []
+    for info in image_infos:
+        image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
+        image, original_size, crop_left_top = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
+        image = IMAGE_TRANSFORMS(image)
+        images.append(image)
+
+        info.latents_original_size = original_size
+        info.latents_crop_left_top = crop_left_top
+
+    img_tensors = torch.stack(images, dim=0)
+    img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
+
+    with torch.no_grad():
+        latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+
+    if flip_aug:
+        img_tensors = torch.flip(img_tensors, dims=[3])
+        with torch.no_grad():
+            flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+    else:
+        flipped_latents = [None] * len(latents)
+
+    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
+        # check NaN
+        if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
+            raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
+
+        if cache_to_disk:
+            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top, flipped_latent)
+        else:
+            info.latents = latent
+            if flip_aug:
+                info.latents_flipped = flipped_latent
+
+
 # endregion

 # region module replacement
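cache_batch_latents encodes each batch once (twice with flip_aug) and either writes the combined npz or keeps the tensors on the ImageInfo objects; callers decide whether an existing cache can be reused via is_disk_cached_latents_is_expected. A sketch of that check (placeholder path and resolution):

# a sketch of the skip-existing check; "img001.npz" and (512, 512) are placeholders
from library import train_util

bucket_reso = (512, 512)  # W x H
if train_util.is_disk_cached_latents_is_expected(bucket_reso, "img001.npz", flip_aug=True):
    print("cache is usable: latents (and latents_flipped) have the expected shape")
else:
    print("cache missing or stale; this image needs to be re-encoded")
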
@@ -3975,7 +4042,7 @@ def sample_images_common(
                     controlnet=controlnet,
                     controlnet_image=controlnet_image,
                 )

                 image = pipeline.latents_to_image(latents)[0]

                 ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
tools/cache_latents.py (new file, 193 lines)
@@ -0,0 +1,193 @@
# latentsのdiskへの事前キャッシュを行う / cache latents to disk

import argparse
import math
from multiprocessing import Value
import os

from accelerate.utils import set_seed
import torch
from tqdm import tqdm

from library import config_util
from library import train_util
from library import sdxl_train_util
from library.config_util import (
    ConfigSanitizer,
    BlueprintGenerator,
)


def cache_to_disk(args: argparse.Namespace) -> None:
    train_util.prepare_dataset_args(args, True)

    # check cache latents arg
    assert args.cache_latents_to_disk, "cache_latents_to_disk must be True / cache_latents_to_diskはTrueである必要があります"

    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random seed

    # prepare the tokenizer: needed to run the dataset
    if args.sdxl:
        tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
        tokenizers = [tokenizer1, tokenizer2]
    else:
        tokenizer = train_util.load_tokenizer(args)
        tokenizers = [tokenizer]

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
        if args.dataset_config is not None:
            print(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                print(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                print("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                print("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizers)
        train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizers)

    # if the dataset's cache_latents is not called, raw images are returned

    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)

    # prepare the accelerator
    print("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # prepare a dtype that matches the mixed precision setting and cast as needed
    weight_dtype, _ = train_util.prepare_dtype(args)
    vae_dtype = torch.float32 if args.no_half_vae else weight_dtype

    # load the model
    print("load model")
    if args.sdxl:
        (_, _, _, vae, _, _, _) = sdxl_train_util.load_target_model(args, accelerator, "sdxl", weight_dtype)
    else:
        _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)

    vae.set_use_memory_efficient_attention_xformers(args.xformers)
    vae.to(accelerator.device, dtype=vae_dtype)
    vae.requires_grad_(False)
    vae.eval()

    # prepare the dataloader
    train_dataset_group.set_caching_mode("latents")

    # number of DataLoader processes: 0 means the main process
    n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, capped at the specified value

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collater,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # prepare with accelerator: this should enable multi-GPU use
    train_dataloader = accelerator.prepare(train_dataloader)

    # loop to fetch the data
    for batch in tqdm(train_dataloader):
        b_size = len(batch["images"])
        vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size
        flip_aug = batch["flip_aug"]
        random_crop = batch["random_crop"]
        bucket_reso = batch["bucket_reso"]

        # split the batch and process it
        for i in range(0, b_size, vae_batch_size):
            images = batch["images"][i : i + vae_batch_size]
            absolute_paths = batch["absolute_paths"][i : i + vae_batch_size]
            resized_sizes = batch["resized_sizes"][i : i + vae_batch_size]

            image_infos = []
            for i, (image, absolute_path, resized_size) in enumerate(zip(images, absolute_paths, resized_sizes)):
                image_info = train_util.ImageInfo(absolute_path, 1, "dummy", False, absolute_path)
                image_info.image = image
                image_info.bucket_reso = bucket_reso
                image_info.resized_size = resized_size
                image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz"

                if args.skip_existing:
                    if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug):
                        print(f"Skipping {image_info.latents_npz} because it already exists.")
                        continue

                image_infos.append(image_info)

            if len(image_infos) > 0:
                train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop)

    accelerator.wait_for_everyone()
    accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.")


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_training_arguments(parser, True)
    train_util.add_dataset_arguments(parser, True, True, True)
    config_util.add_config_arguments(parser)
    parser.add_argument("--sdxl", action="store_true", help="Use SDXL model / SDXLモデルを使用する")
    parser.add_argument(
        "--no_half_vae",
        action="store_true",
        help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="skip images if npz already exists (both normal and flipped exists if flip_aug is enabled) / npzが既に存在する画像をスキップする(flip_aug有効時は通常、反転の両方が存在する画像をスキップ)",
    )
    return parser


if __name__ == "__main__":
    parser = setup_parser()

    args = parser.parse_args()
    args = train_util.read_config_from_file(args, parser)

    cache_to_disk(args)