Files
sign-predictor/src/datasets/finger_spelling_dataset.py
2023-03-12 19:34:04 +00:00

79 lines
2.8 KiB
Python

import os
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from src.identifiers import LANDMARKS
from src.keypoint_extractor import KeypointExtractor
class FingerSpellingDataset(torch.utils.data.Dataset):
    """Dataset of finger-spelling videos.

    Scans ``data_folder`` for ``*.mp4`` files whose class label is encoded in
    the filename as the prefix before the first ``"!"``, builds an integer
    label mapping (persisted to ``label_mapping.txt`` in the same folder), and
    performs a stratified 70/30 train/validation split of the files.
    """

    def __init__(self, data_folder: str, keypoint_extractor: "KeypointExtractor",
                 subset: str = "train", keypoints_identifier: dict = None,
                 transform=None):
        """Index the videos in ``data_folder`` and select one split.

        Args:
            data_folder: Directory containing ``<label>!*.mp4`` video files.
            keypoint_extractor: Object providing
                ``extract_keypoints_from_video(name, normalize=...)`` that
                returns a DataFrame of per-frame keypoint columns.
            subset: Either ``"train"`` or ``"val"``.
            keypoints_identifier: Optional mapping whose values are landmark
                ids; when given, only those landmarks' x/y columns are kept.
            transform: Optional callable applied to the keypoint tensor.

        Raises:
            ValueError: If ``subset`` is neither ``"train"`` nor ``"val"``.
        """
        self.data_folder = data_folder
        # Videos are named "<label>!<rest>.mp4"; the label is the prefix.
        files = [f for f in os.listdir(self.data_folder) if f.endswith(".mp4")]
        labels = [f.split("!")[0] for f in files]
        # Unique labels in sorted order; the position in this array is the
        # integer class id.
        self.label_mapping, counts = np.unique(labels, return_counts=True)
        # Persist the mapping so predictions can be decoded later.
        # BUG FIX: the original omitted "\n", writing every entry onto a
        # single line of label_mapping.txt.
        with open(os.path.join(self.data_folder, "label_mapping.txt"), "w") as f:
            for i, label in enumerate(self.label_mapping):
                f.write(f"{label} {i}\n")
        # Map each string label to its integer class id.
        labels = [np.where(self.label_mapping == label)[0][0] for label in labels]
        # TODO: make split for train and val and test when enough data is available
        # Stratified split keeps the per-class balance in both subsets.
        x_train, x_test, y_train, y_test = train_test_split(
            files, labels, test_size=0.3, random_state=1, stratify=labels)
        if subset == "train":
            self.data = x_train
            self.labels = y_train
        elif subset == "val":
            self.data = x_test
            self.labels = y_test
        else:
            # BUG FIX: fail fast instead of leaving self.data/self.labels
            # unset, which previously surfaced as a confusing AttributeError
            # only when __len__/__getitem__ was first called.
            raise ValueError(f"Unknown subset {subset!r}; expected 'train' or 'val'")
        self.transform = transform
        self.subset = subset
        self.keypoint_extractor = keypoint_extractor
        # BUG FIX: always define the attribute. The original set it only when
        # keypoints_identifier was truthy, so __getitem__ crashed with an
        # AttributeError for the default keypoints_identifier=None.
        self.keypoints_to_keep = None
        if keypoints_identifier:
            # Keep only the x/y columns of the requested landmark ids.
            self.keypoints_to_keep = [f"{i}_{j}"
                                      for i in keypoints_identifier.values()
                                      for j in ["x", "y"]]

    def __len__(self):
        """Number of videos in the selected subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(keypoints, label)`` for the ``index``-th video.

        ``keypoints`` is a float64 tensor of shape ``(frames, n_points, 2)``
        holding an (x, y) pair per landmark per frame.
        """
        video_name = self.data[index]
        # NOTE(review): "minxmax" looks like a typo for "minmax" — confirm
        # against KeypointExtractor.extract_keypoints_from_video before fixing.
        # NOTE(review): video_name is a bare filename, never joined with
        # self.data_folder — verify the extractor resolves relative paths.
        keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(
            video_name, normalize="minxmax")
        # Optionally restrict to the configured subset of landmark columns.
        if self.keypoints_to_keep:
            keypoints_df = keypoints_df[self.keypoints_to_keep]
        # Reshape (frames, 2 * n_points) -> (frames, n_points, 2); columns
        # arrive as consecutive (x, y) pairs.
        current_row = np.empty(shape=(keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2))
        for i in range(0, keypoints_df.shape[1], 2):
            current_row[:, i // 2, 0] = keypoints_df.iloc[:, i]
            current_row[:, i // 2, 1] = keypoints_df.iloc[:, i + 1]
        label = self.labels[index]
        data = torch.from_numpy(current_row)
        if self.transform:
            data = self.transform(data)
        return data, label