import os

import numpy as np
import torch
from sklearn.model_selection import train_test_split

from src.identifiers import LANDMARKS
from src.keypoint_extractor import KeypointExtractor


class FingerSpellingDataset(torch.utils.data.Dataset):
    """Dataset of finger-spelling videos labelled via the filename prefix.

    Videos are expected as ``<label>!<anything>.mp4`` inside ``data_folder``.
    Keypoints are extracted lazily per item through the supplied
    ``KeypointExtractor`` and returned as a ``(frames, n_keypoints, 2)``
    float tensor together with the integer class label.
    """

    def __init__(self, data_folder: str, keypoint_extractor: KeypointExtractor,
                 subset: str = "train", keypoints_identifier: dict = None,
                 transform=None):
        """Build the train/val split over the videos found in *data_folder*.

        Args:
            data_folder: Directory containing the ``*.mp4`` videos.
            keypoint_extractor: Object providing
                ``extract_keypoints_from_video(name, normalize=...)``.
            subset: ``"train"`` or ``"val"``; anything else raises ValueError.
            keypoints_identifier: Optional mapping whose values name the
                landmarks to keep; ``None``/empty keeps all columns.
            transform: Optional callable applied to the keypoint tensor.

        Raises:
            ValueError: If *subset* is neither ``"train"`` nor ``"val"``.
        """
        self.data_folder = data_folder

        # The class label is the part of the filename before "!".
        files = [f for f in os.listdir(self.data_folder) if f.endswith(".mp4")]
        labels = [f.split("!")[0] for f in files]

        # Unique labels (sorted by np.unique) define the label -> index mapping.
        self.label_mapping, counts = np.unique(labels, return_counts=True)

        # Persist the mapping so predictions can be decoded later.
        # Fix: entries were previously written without a newline separator,
        # producing one unreadable run-on line.
        with open(os.path.join(self.data_folder, "label_mapping.txt"), "w") as f:
            for i, label in enumerate(self.label_mapping):
                f.write(f"{label} {i}\n")

        # Map string labels to integer indices.  A dict gives O(1) lookups
        # instead of an O(n) np.where scan per label.
        label_to_idx = {label: i for i, label in enumerate(self.label_mapping)}
        labels = [label_to_idx[label] for label in labels]

        # TODO: make split for train and val and test when enough data is available
        # Stratified split keeps the class balance in both subsets.
        x_train, x_test, y_train, y_test = train_test_split(
            files, labels, test_size=0.3, random_state=1, stratify=labels)

        if subset == "train":
            self.data = x_train
            self.labels = y_train
        elif subset == "val":
            self.data = x_test
            self.labels = y_test
        else:
            # Fix: fail fast instead of leaving self.data / self.labels
            # undefined and crashing later with AttributeError.
            raise ValueError(
                f"Unknown subset: {subset!r} (expected 'train' or 'val')")

        self.transform = transform
        self.subset = subset
        self.keypoint_extractor = keypoint_extractor

        # Columns to keep: an "<landmark>_x"/"<landmark>_y" pair per
        # identified landmark; None means keep all columns.
        # Fix: the attribute was previously only set when keypoints_identifier
        # was truthy, so the default (None) made __getitem__ raise
        # AttributeError.
        if keypoints_identifier:
            self.keypoints_to_keep = [
                f"{i}_{j}"
                for i in keypoints_identifier.values()
                for j in ["x", "y"]
            ]
        else:
            self.keypoints_to_keep = None

    def __len__(self):
        """Number of videos in the selected subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(keypoints, label)`` for the *index*-th video.

        ``keypoints`` is a float64 tensor shaped (frames, n_keypoints, 2),
        with ``[..., 0]`` = x and ``[..., 1]`` = y.
        """
        video_name = self.data[index]

        # NOTE(review): "minxmax" looks like a typo for "minmax" — confirm
        # against KeypointExtractor.extract_keypoints_from_video before
        # changing, since the string is interpreted by project code.
        # NOTE(review): video_name is a bare filename, not joined with
        # self.data_folder — verify the extractor resolves it correctly.
        keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(
            video_name, normalize="minxmax")

        # Restrict to the requested landmark columns, if any were configured.
        if self.keypoints_to_keep:
            keypoints_df = keypoints_df[self.keypoints_to_keep]

        # Reshape the flat (frames, 2*k) column layout into (frames, k, 2)
        # [x, y] pairs; columns alternate x, y per keypoint.
        current_row = np.empty(
            shape=(keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2))
        for i in range(0, keypoints_df.shape[1], 2):
            current_row[:, i // 2, 0] = keypoints_df.iloc[:, i]
            current_row[:, i // 2, 1] = keypoints_df.iloc[:, i + 1]

        label = self.labels[index]

        data = torch.from_numpy(current_row)
        if self.transform:
            data = self.transform(data)
        return data, label