fix images are used twice, update debug dataset

This commit is contained in:
Kohya S
2023-03-27 20:48:21 +09:00
parent 43a08b4061
commit 238f01bc9c

View File

@@ -481,8 +481,7 @@ class BaseDataset(torch.utils.data.Dataset):
else: else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0: if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(",")] tokens = [t.strip() for t in caption.strip().split(",")]
print(subset.token_warmup_min, subset.token_warmup_step) if subset.token_warmup_step < 1: # 初回に上書きする
if subset.token_warmup_step < 1:
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps) subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step: if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = ( tokens_len = (
@@ -1342,50 +1341,55 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
def debug_dataset(train_dataset, show_input_ids=False): def debug_dataset(train_dataset, show_input_ids=False):
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}") print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
print("`E` for increment (pseudo) epoch no. , Escape for exit. / Eキーで疑似的にエポック番号を+1、Escキーで中断、終了します") print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
epoch = 1 epoch = 1
steps = 1 while True:
train_dataset.set_current_epoch(epoch) print(f"epoch: {epoch}")
train_dataset.set_current_step(steps)
k = 0 steps = (epoch - 1) * len(train_dataset) + 1
indices = list(range(len(train_dataset))) indices = list(range(len(train_dataset)))
random.shuffle(indices) random.shuffle(indices)
for i, idx in enumerate(indices):
example = train_dataset[idx]
if example["latents"] is not None:
print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate(
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
if show_input_ids:
print(f"input ids: {iid}")
if example["images"] is not None:
im = example["images"][j]
print(f"image size: {im.size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27:
break
if k == ord("e"):
epoch += 1
steps = len(train_dataset) * (epoch - 1)
train_dataset.set_current_epoch(epoch)
print(f"epoch: {epoch}")
steps += 1 k = 0
train_dataset.set_current_step(steps) for i, idx in enumerate(indices):
train_dataset.set_current_epoch(epoch)
train_dataset.set_current_step(steps)
print(f"steps: {steps} ({i + 1}/{len(train_dataset)})")
if k == 27 or (example["images"] is None and i >= 8): example = train_dataset[idx]
if example["latents"] is not None:
print(f"sample has latents from npz file: {example['latents'].size()}")
for j, (ik, cap, lw, iid) in enumerate(
zip(example["image_keys"], example["captions"], example["loss_weights"], example["input_ids"])
):
print(f'{ik}, size: {train_dataset.image_data[ik].image_size}, loss weight: {lw}, caption: "{cap}"')
if show_input_ids:
print(f"input ids: {iid}")
if example["images"] is not None:
im = example["images"][j]
print(f"image size: {im.size()}")
im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
if os.name == "nt": # only windows
cv2.imshow("img", im)
k = cv2.waitKey()
cv2.destroyAllWindows()
if k == 27 or k == ord("s") or k == ord("e"):
break
steps += 1
if k == ord("e"):
break
if k == 27 or (example["images"] is None and i >= 8):
k = 27
break
if k == 27:
break break
epoch += 1
def glob_images(directory, base="*"): def glob_images(directory, base="*"):
img_paths = [] img_paths = []
@@ -1394,8 +1398,8 @@ def glob_images(directory, base="*"):
img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext))) img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
else: else:
img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext)))) img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
# img_paths = list(set(img_paths)) # 重複を排除 img_paths = list(set(img_paths)) # 重複を排除
# img_paths.sort() img_paths.sort()
return img_paths return img_paths
@@ -1407,8 +1411,8 @@ def glob_images_pathlib(dir_path, recursive):
else: else:
for ext in IMAGE_EXTENSIONS: for ext in IMAGE_EXTENSIONS:
image_paths += list(dir_path.glob("*" + ext)) image_paths += list(dir_path.glob("*" + ext))
# image_paths = list(set(image_paths)) # 重複を排除 image_paths = list(set(image_paths)) # 重複を排除
# image_paths.sort() image_paths.sort()
return image_paths return image_paths