Files
sign-predictor/src/datasets/wlasl_dataset.py
2023-02-27 13:34:26 +00:00

66 lines
2.1 KiB
Python

import json
from collections import OrderedDict
import numpy as np
import torch
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
class WLASLDataset(torch.utils.data.Dataset):
    """PyTorch dataset over WLASL sign-language videos.

    Each item is ``(keypoints, label)`` where ``keypoints`` is a float64
    tensor of shape ``(frames, num_keypoints, 2)`` (x/y in the last axis)
    extracted on the fly by a ``KeypointExtractor``, and ``label`` is the
    first entry of the video's ``"action"`` list from the JSON index.
    """

    def __init__(self, json_file: str, missing: str,
                 keypoint_extractor: "KeypointExtractor",
                 subset: str = "train",
                 keypoints_identifier: dict = None,
                 transform=None):
        """Build the video index for one subset.

        Args:
            json_file: Path to the WLASL JSON index (video id -> metadata).
            missing: Path to a text file listing video ids to exclude,
                one id per line.
            keypoint_extractor: Object providing
                ``extract_keypoints_from_video(video_name) -> DataFrame``.
            subset: Which split to keep ("train", "val", "test", ...).
            keypoints_identifier: Optional mapping whose values are keypoint
                indices; when given, only the matching ``<idx>_x``/``<idx>_y``
                columns are kept.
            transform: Stored for API compatibility.
                NOTE(review): currently never applied in __getitem__ —
                confirm whether it should be.
        """
        # Drop every video id listed in the `missing` file from the index.
        with open(missing) as f:
            missing_ids = f.read().splitlines()
        with open(json_file) as f:
            data = json.load(f)
        for vid in missing_ids:
            if vid in data:
                del data[vid]
        # Keep only the requested split, preserving the JSON file's order
        # so integer indexing is stable across runs.
        self.data = OrderedDict(
            (k, v) for k, v in data.items() if v["subset"] == subset
        )
        self.transform = transform
        self.subset = subset
        self.keypoint_extractor = keypoint_extractor
        # BUGFIX: always define the attribute. __getitem__ reads
        # self.keypoints_to_keep unconditionally, so leaving it unset when
        # keypoints_identifier is None raised AttributeError on first access.
        if keypoints_identifier:
            self.keypoints_to_keep = [
                f"{i}_{j}"
                for i in keypoints_identifier.values()
                for j in ("x", "y")
            ]
        else:
            self.keypoints_to_keep = None

    def __len__(self):
        """Number of videos in the selected subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(keypoints_tensor, label)`` for the index-th video.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        # The i-th key of the ordered dict identifies the video.
        video_id = list(self.data.keys())[index]
        video_name = f"{video_id}.mp4"
        keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name)
        # Optionally restrict to the configured keypoint columns.
        if self.keypoints_to_keep:
            keypoints_df = keypoints_df[self.keypoints_to_keep]
        # Columns come in (x, y) pairs; fold them into a trailing axis of 2.
        frames = np.empty(shape=(keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2))
        for col in range(0, keypoints_df.shape[1], 2):
            frames[:, col // 2, 0] = keypoints_df.iloc[:, col]
            frames[:, col // 2, 1] = keypoints_df.iloc[:, col + 1]
        label = self.data[video_id]["action"][0]
        return torch.from_numpy(frames), label