96 lines
3.4 KiB
Python
96 lines
3.4 KiB
Python
import os
|
|
|
|
import numpy as np
|
|
import torch
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
from src.identifiers import LANDMARKS
|
|
from src.keypoint_extractor import KeypointExtractor
|
|
|
|
|
|
class FingerSpellingDataset(torch.utils.data.Dataset):
    """Dataset of finger-spelling videos turned into per-frame 2D keypoint arrays.

    Videos are discovered as ``*.mp4`` files in ``data_folder`` (and optionally
    ``bad_data_folder``). Each file name is expected to look like
    ``<label>!<split>!...mp4``: the text before the first ``!`` is the class
    label and the text after it is the split tag compared against ``subset``.

    Keypoints are extracted lazily in ``__getitem__`` and cached as ``.npy``
    files in ``CACHE_DIR`` (relative to the current working directory).

    Attributes:
        label_mapping: sorted array of the unique labels; ``label_mapping[i]``
            is the label of integer class ``i``.
        data: array of video paths belonging to the selected subset.
        labels: integer class ids aligned with ``data``.
        subset: the subset name passed to the constructor.
        transform: optional callable applied to each sample tensor.
        keypoint_extractor: extractor used to turn a video into keypoints.
        keypoints_to_keep: column names (``"<id>_x"``, ``"<id>_y"``) to keep,
            or ``None`` to keep every column returned by the extractor.
    """

    # Directory (relative to the CWD) where processed keypoint arrays are cached.
    CACHE_DIR = "cache_processed"

    def __init__(self, data_folder: str, bad_data_folder: str = "", subset: str = "train",
                 keypoints_identifier: dict = None, transform=None):
        """Index the videos and keep only those of the requested subset.

        Args:
            data_folder: directory containing the ``.mp4`` videos.
            bad_data_folder: optional second directory whose ``.mp4`` videos are
                added to the pool; skipped when empty (or ``None``).
            subset: ``"train"`` or ``"test"``; only videos whose split tag equals
                this value are kept.
            keypoints_identifier: optional mapping whose *values* are keypoint
                ids; when given, only the ``<id>_x`` / ``<id>_y`` columns are kept.
            transform: optional callable applied to each sample tensor.

        Raises:
            ValueError: if ``subset`` is not ``"train"`` or ``"test"``, or if a
                video file name has no ``!``-separated split tag.
        """
        files = self._list_videos(data_folder)
        if bad_data_folder:
            files += self._list_videos(bad_data_folder)

        parsed = [self._parse_video_name(path) for path in files]
        labels = [label for label, _ in parsed]
        split_tags = [split for _, split in parsed]

        # np.unique sorts the labels; return_inverse gives every file's index into
        # that sorted array in one pass (no per-file np.where scan).
        self.label_mapping, label_ids = np.unique(labels, return_inverse=True)

        # TODO: make split for train and val and test when enough data is available
        mask = self._subset_mask(split_tags, subset)

        self.data = np.array(files)[mask]
        self.labels = np.asarray(label_ids)[mask]

        self.transform = transform
        self.subset = subset
        self.keypoint_extractor = KeypointExtractor()

        # Always defined so __getitem__ can test it; None means "keep all columns".
        self.keypoints_to_keep = None
        if keypoints_identifier:
            self.keypoints_to_keep = [f"{i}_{j}" for i in keypoints_identifier.values() for j in ["x", "y"]]

    def __len__(self):
        """Return the number of videos in the selected subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(keypoints, label)`` for the ``index``-th video of the subset.

        ``keypoints`` is a tensor of shape ``(rows, n_keypoints, 2)`` holding the
        (x, y) pair of every kept keypoint per extractor row (presumably one row
        per video frame — TODO confirm against KeypointExtractor). ``label`` is
        the integer class id; see ``label_mapping`` for its string label.
        """
        video_name = self.data[index]
        cache_path = os.path.join(self.CACHE_DIR, self._cache_name(video_name))

        # NOTE(review): the cache key is only the video file name, so changing
        # keypoints_identifier or the normalization silently reuses stale arrays.
        if os.path.isfile(cache_path):
            current_row = np.load(cache_path)
        else:
            # normalizations supported by the extractor: minxmax, bohacek
            keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name, normalize="bohacek")

            if self.keypoints_to_keep:
                keypoints_df = keypoints_df[self.keypoints_to_keep]

            current_row = self._keypoints_to_array(keypoints_df)

            os.makedirs(self.CACHE_DIR, exist_ok=True)
            np.save(cache_path, current_row)

        label = self.labels[index]
        data = torch.from_numpy(current_row)

        if self.transform:
            data = self.transform(data)

        return data, label

    @staticmethod
    def _list_videos(folder):
        """Return the sorted paths of the ``.mp4`` files directly inside ``folder``."""
        # os.path.join works whether or not ``folder`` ends with a separator;
        # sorting makes the dataset order independent of the filesystem.
        return [os.path.join(folder, name) for name in sorted(os.listdir(folder)) if name.endswith(".mp4")]

    @staticmethod
    def _parse_video_name(path):
        """Split ``<label>!<split>!...`` from a video path into ``(label, split)``."""
        parts = os.path.basename(path).split("!")
        if len(parts) < 2:
            raise ValueError(f"video file name has no '!'-separated split tag: {path!r}")
        return parts[0], parts[1]

    @staticmethod
    def _subset_mask(split_tags, subset):
        """Return a boolean mask selecting the entries of ``split_tags`` equal to ``subset``."""
        if subset not in ("train", "test"):
            raise ValueError(f"subset must be 'train' or 'test', got {subset!r}")
        # dtype=str keeps the comparison elementwise even for an empty list.
        return np.array(split_tags, dtype=str) == subset

    @staticmethod
    def _cache_name(video_name):
        """Return the ``.npy`` cache file name for a video (its base name minus extension)."""
        # splitext only drops the final extension, so names containing dots do
        # not collide on a truncated prefix.
        return os.path.splitext(os.path.basename(video_name))[0] + ".npy"

    @staticmethod
    def _keypoints_to_array(keypoints_df):
        """Reshape a ``(rows, 2 * k)`` table of x/y column pairs into ``(rows, k, 2)``.

        Columns are expected to alternate ``x, y`` per keypoint. A fresh,
        writable float64 array is returned (``np.array`` copies), so the
        resulting tensor may safely be modified in place by a transform.
        """
        values = np.array(keypoints_df, dtype=np.float64)
        if values.ndim != 2 or values.shape[1] % 2:
            raise ValueError(f"expected a 2D table with an even number of columns, got shape {values.shape}")
        return values.reshape(values.shape[0], values.shape[1] // 2, 2)
|