mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
refactor caching latents (flip in same npz, etc)
This commit is contained in:
@@ -34,22 +34,7 @@ def collate_fn_remove_corrupted(batch):
|
||||
return batch
|
||||
|
||||
|
||||
def get_latents(vae, key_and_images, weight_dtype):
|
||||
img_tensors = [IMAGE_TRANSFORMS(image) for _, image in key_and_images]
|
||||
img_tensors = torch.stack(img_tensors)
|
||||
img_tensors = img_tensors.to(DEVICE, weight_dtype)
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample()
|
||||
|
||||
# check NaN
|
||||
for (key, _), latents1 in zip(key_and_images, latents):
|
||||
if torch.isnan(latents1).any():
|
||||
raise ValueError(f"NaN detected in latents of {key}")
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
|
||||
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
|
||||
if is_full_path:
|
||||
base_name = os.path.splitext(os.path.basename(image_key))[0]
|
||||
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
|
||||
@@ -57,13 +42,10 @@ def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
|
||||
base_name = image_key
|
||||
relative_path = ""
|
||||
|
||||
if flip:
|
||||
base_name += "_flip"
|
||||
|
||||
if recursive and relative_path:
|
||||
return os.path.join(data_dir, relative_path, base_name)
|
||||
return os.path.join(data_dir, relative_path, base_name) + ".npz"
|
||||
else:
|
||||
return os.path.join(data_dir, base_name)
|
||||
return os.path.join(data_dir, base_name) + ".npz"
|
||||
|
||||
|
||||
def main(args):
|
||||
@@ -113,36 +95,7 @@ def main(args):
|
||||
def process_batch(is_last):
|
||||
for bucket in bucket_manager.buckets:
|
||||
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
|
||||
latents = get_latents(vae, [(key, img) for key, img, _, _ in bucket], weight_dtype)
|
||||
assert (
|
||||
latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
|
||||
), f"latent shape {latents.shape}, {bucket[0][1].shape}"
|
||||
|
||||
for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
|
||||
train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
|
||||
|
||||
# flip
|
||||
if args.flip_aug:
|
||||
latents = get_latents(
|
||||
vae, [(key, img[:, ::-1].copy()) for key, img, _, _ in bucket], weight_dtype
|
||||
) # copyがないとTensor変換できない
|
||||
|
||||
for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
|
||||
npz_file_name = get_npz_filename_wo_ext(
|
||||
args.train_data_dir, image_key, args.full_path, True, args.recursive
|
||||
)
|
||||
train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
|
||||
else:
|
||||
# remove existing flipped npz
|
||||
for image_key, _ in bucket:
|
||||
npz_file_name = (
|
||||
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
|
||||
)
|
||||
if os.path.isfile(npz_file_name):
|
||||
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
|
||||
os.remove(npz_file_name)
|
||||
|
||||
train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
|
||||
bucket.clear()
|
||||
|
||||
# 読み込みの高速化のためにDataLoaderを使うオプション
|
||||
@@ -203,61 +156,18 @@ def main(args):
|
||||
), f"internal error resized size is small: {resized_size}, {reso}"
|
||||
|
||||
# 既に存在するファイルがあればshape等を確認して同じならskipする
|
||||
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
|
||||
if args.skip_existing:
|
||||
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
|
||||
if args.flip_aug:
|
||||
npz_files.append(
|
||||
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
|
||||
)
|
||||
|
||||
found = True
|
||||
for npz_file in npz_files:
|
||||
if not os.path.exists(npz_file):
|
||||
found = False
|
||||
break
|
||||
|
||||
latents, _, _ = train_util.load_latents_from_disk(npz_file)
|
||||
if latents is None: # old version
|
||||
found = False
|
||||
break
|
||||
|
||||
if latents.shape[1] != reso[1] // 8 or latents.shape[2] != reso[0] // 8: # latentsのshapeを確認
|
||||
found = False
|
||||
break
|
||||
if found:
|
||||
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
|
||||
continue
|
||||
|
||||
# 画像をリサイズしてトリミングする
|
||||
# PILにinter_areaがないのでcv2で……
|
||||
image = np.array(image)
|
||||
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
|
||||
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
|
||||
|
||||
trim_left = 0
|
||||
if resized_size[0] > reso[0]:
|
||||
trim_size = resized_size[0] - reso[0]
|
||||
image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
|
||||
trim_left = trim_size // 2
|
||||
|
||||
trim_top = 0
|
||||
if resized_size[1] > reso[1]:
|
||||
trim_size = resized_size[1] - reso[1]
|
||||
image = image[trim_size // 2 : trim_size // 2 + reso[1]]
|
||||
trim_top = trim_size // 2
|
||||
|
||||
original_size_wh = (resized_size[0], resized_size[1])
|
||||
# target_size_wh = (reso[0], reso[1])
|
||||
crop_left_top = (trim_left, trim_top)
|
||||
|
||||
assert (
|
||||
image.shape[0] == reso[1] and image.shape[1] == reso[0]
|
||||
), f"internal error, illegal trimmed size: {image.shape}, {reso}"
|
||||
|
||||
# # debug
|
||||
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
|
||||
|
||||
# バッチへ追加
|
||||
bucket_manager.add_image(reso, (image_key, image, original_size_wh, crop_left_top))
|
||||
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
|
||||
image_info.latents_npz = npz_file_name
|
||||
image_info.bucket_reso = reso
|
||||
image_info.resized_size = resized_size
|
||||
image_info.image = image
|
||||
bucket_manager.add_image(reso, image_info)
|
||||
|
||||
# バッチを推論するか判定して推論する
|
||||
process_batch(False)
|
||||
|
||||
Reference in New Issue
Block a user