support sdxl in prepare script

This commit is contained in:
Kohya S
2023-07-07 21:16:41 +09:00
parent 4a34e5804e
commit cc3d40ca44
2 changed files with 77 additions and 44 deletions

View File

@@ -34,12 +34,18 @@ def collate_fn_remove_corrupted(batch):
return batch return batch
def get_latents(vae, images, weight_dtype): def get_latents(vae, key_and_images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for image in images] img_tensors = [IMAGE_TRANSFORMS(image) for _, image in key_and_images]
img_tensors = torch.stack(img_tensors) img_tensors = torch.stack(img_tensors)
img_tensors = img_tensors.to(DEVICE, weight_dtype) img_tensors = img_tensors.to(DEVICE, weight_dtype)
with torch.no_grad(): with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy() latents = vae.encode(img_tensors).latent_dist.sample()
# check NaN
for (key, _), latents1 in zip(key_and_images, latents):
if torch.isnan(latents1).any():
raise ValueError(f"NaN detected in latents of {key}")
return latents return latents
@@ -107,24 +113,26 @@ def main(args):
def process_batch(is_last): def process_batch(is_last):
for bucket in bucket_manager.buckets: for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [img for _, img in bucket], weight_dtype) latents = get_latents(vae, [(key, img) for key, img, _, _ in bucket], weight_dtype)
assert ( assert (
latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8 latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
), f"latent shape {latents.shape}, {bucket[0][1].shape}" ), f"latent shape {latents.shape}, {bucket[0][1].shape}"
for (image_key, _), latent in zip(bucket, latents): for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
np.savez(npz_file_name, latent) train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
# flip # flip
if args.flip_aug: if args.flip_aug:
latents = get_latents(vae, [img[:, ::-1].copy() for _, img in bucket], weight_dtype) # copyがないとTensor変換できない latents = get_latents(
vae, [(key, img[:, ::-1].copy()) for key, img, _, _ in bucket], weight_dtype
) # copyがないとTensor変換できない
for (image_key, _), latent in zip(bucket, latents): for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext( npz_file_name = get_npz_filename_wo_ext(
args.train_data_dir, image_key, args.full_path, True, args.recursive args.train_data_dir, image_key, args.full_path, True, args.recursive
) )
np.savez(npz_file_name, latent) train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
else: else:
# remove existing flipped npz # remove existing flipped npz
for image_key, _ in bucket: for image_key, _ in bucket:
@@ -194,7 +202,7 @@ def main(args):
resized_size[0] >= reso[0] and resized_size[1] >= reso[1] resized_size[0] >= reso[0] and resized_size[1] >= reso[1]
), f"internal error resized size is small: {resized_size}, {reso}" ), f"internal error resized size is small: {resized_size}, {reso}"
# 既に存在するファイルがあればshapeを確認して同じならskipする # 既に存在するファイルがあればshapeを確認して同じならskipする
if args.skip_existing: if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"] npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
if args.flip_aug: if args.flip_aug:
@@ -208,8 +216,12 @@ def main(args):
found = False found = False
break break
dat = np.load(npz_file)["arr_0"] latents, _, _ = train_util.load_latents_from_disk(npz_file)
if dat.shape[1] != reso[1] // 8 or dat.shape[2] != reso[0] // 8: # latentsのshapeを確認 if latents is None: # old version
found = False
break
if latents.shape[1] != reso[1] // 8 or latents.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False found = False
break break
if found: if found:
@@ -221,13 +233,21 @@ def main(args):
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要? if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
trim_left = 0
if resized_size[0] > reso[0]: if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0] trim_size = resized_size[0] - reso[0]
image = image[:, trim_size // 2 : trim_size // 2 + reso[0]] image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
trim_left = trim_size // 2
trim_top = 0
if resized_size[1] > reso[1]: if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1] trim_size = resized_size[1] - reso[1]
image = image[trim_size // 2 : trim_size // 2 + reso[1]] image = image[trim_size // 2 : trim_size // 2 + reso[1]]
trim_top = trim_size // 2
original_size_wh = (resized_size[0], resized_size[1])
# target_size_wh = (reso[0], reso[1])
crop_left_top = (trim_left, trim_top)
assert ( assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0] image.shape[0] == reso[1] and image.shape[1] == reso[0]
@@ -237,7 +257,7 @@ def main(args):
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1]) # cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])
# バッチへ追加 # バッチへ追加
bucket_manager.add_image(reso, (image_key, image)) bucket_manager.add_image(reso, (image_key, image, original_size_wh, crop_left_top))
# バッチを推論するか判定して推論する # バッチを推論するか判定して推論する
process_batch(False) process_batch(False)

View File

@@ -124,11 +124,11 @@ class BucketManager:
self.resos = [] self.resos = []
self.reso_to_id = {} self.reso_to_id = {}
self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key self.buckets = [] # 前処理時は (image_key, image, original size, crop left/top)、学習時は image_key
def add_image(self, reso, image): def add_image(self, reso, image_or_info):
bucket_id = self.reso_to_id[reso] bucket_id = self.reso_to_id[reso]
self.buckets[bucket_id].append(image) self.buckets[bucket_id].append(image_or_info)
def shuffle(self): def shuffle(self):
for bucket in self.buckets: for bucket in self.buckets:
@@ -767,7 +767,10 @@ class BaseDataset(torch.utils.data.Dataset):
img = np.array(image, np.uint8) img = np.array(image, np.uint8)
return img return img
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size): # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top)
def trim_and_resize_if_required(
self, subset: BaseSubset, image, reso, resized_size
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
image_height, image_width = image.shape[0:2] image_height, image_width = image.shape[0:2]
if image_width != resized_size[0] or image_height != resized_size[1]: if image_width != resized_size[0] or image_height != resized_size[1]:
@@ -907,19 +910,13 @@ class BaseDataset(torch.utils.data.Dataset):
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents):
# check NaN # check NaN
for info, latents1 in zip(batch, latents): if torch.isnan(latents).any():
if torch.isnan(latents1).any():
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
for info, latent in zip(batch, latents):
if cache_to_disk: if cache_to_disk:
np.savez( save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top)
info.latents_npz,
latents=latent.float().numpy(),
original_size=np.array(info.latents_original_size),
crop_left_top=np.array(info.latents_crop_left_top),
)
else: else:
info.latents = latent info.latents = latent
@@ -927,12 +924,14 @@ class BaseDataset(torch.utils.data.Dataset):
img_tensors = torch.flip(img_tensors, dims=[3]) img_tensors = torch.flip(img_tensors, dims=[3])
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
for info, latent in zip(batch, latents): for info, latent in zip(batch, latents):
# check NaN
if torch.isnan(latents).any():
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
if cache_to_disk: if cache_to_disk:
np.savez( # crop_left_top is reversed when making batch
info.latents_npz_flipped, save_latents_to_disk(
latents=latent.float().numpy(), info.latents_npz_flipped, latent, info.latents_original_size, info.latents_crop_left_top
original_size=np.array(info.latents_original_size),
crop_left_top=np.array(info.latents_crop_left_top), # reverse horizontally when use flipped latents
) )
else: else:
info.latents_flipped = latent info.latents_flipped = latent
@@ -1005,18 +1004,7 @@ class BaseDataset(torch.utils.data.Dataset):
def load_latents_from_npz(self, image_info: ImageInfo, flipped): def load_latents_from_npz(self, image_info: ImageInfo, flipped):
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
if npz_file is None: return load_latents_from_disk(npz_file)
return None, None, None
npz = np.load(npz_file)
if "latents" not in npz:
print(f"error: npz is old format. please re-generate {npz_file}")
return None, None, None
latents = npz["latents"]
original_size = npz["original_size"].tolist()
crop_left_top = npz["crop_left_top"].tolist()
return latents, original_size, crop_left_top
def __len__(self): def __len__(self):
return self._length return self._length
@@ -1762,6 +1750,31 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
dataset.disable_token_padding() dataset.disable_token_padding()
# Returns: latents_tensor, (original_size width, original_size height), (crop left, crop top)
def load_latents_from_disk(npz_path) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]]]:
    """Load cached latents and their SDXL conditioning metadata from an .npz file.

    Returns a (latents, original_size, crop_left_top) triple, or
    (None, None, None) when there is nothing usable: either npz_path is
    None (a flipped cache that was never written) or the file predates the
    "latents" key and must be regenerated.
    """
    if npz_path is None:  # flipped doesn't exist
        return None, None, None

    data = np.load(npz_path)
    if "latents" not in data:  # old cache format lacks named arrays
        print(f"error: npz is old format. please re-generate {npz_path}")
        return None, None, None

    return data["latents"], data["original_size"].tolist(), data["crop_left_top"].tolist()
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top):
    """Persist latents plus their SDXL conditioning metadata as an .npz file.

    The tensor is stored as float32 on CPU under the "latents" key; the
    (width, height) original size and (left, top) crop offsets are stored
    as numpy arrays so load_latents_from_disk can round-trip them.
    """
    payload = {
        "latents": latents_tensor.float().cpu().numpy(),
        "original_size": np.array(original_size),
        "crop_left_top": np.array(crop_left_top),
    }
    np.savez(npz_path, **payload)
def debug_dataset(train_dataset, show_input_ids=False): def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します") print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")