# WLASL dataset: loads per-video keypoint sequences for sign-language recognition.
# Standard library
import json
from collections import OrderedDict

# Third-party
import numpy as np
import torch

# Project-local
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
class WLASLDataset(torch.utils.data.Dataset):
    """PyTorch dataset over the WLASL sign-language video corpus.

    Reads the WLASL metadata JSON, drops videos listed in the ``missing``
    file, keeps only entries belonging to ``subset``, and serves one
    keypoint tensor per video.

    Each item is ``(data, label)`` where ``data`` is a float64 tensor of
    shape ``(frames, num_keypoints, 2)`` (x/y per keypoint) and ``label``
    is the first element of the entry's ``"action"`` list.
    """

    def __init__(self, json_file: str, missing: str, keypoint_extractor: "KeypointExtractor",
                 subset: str = "train", keypoints_identifier: dict = None, transform=None):
        """
        Args:
            json_file: path to the WLASL metadata JSON (video-id -> entry dict).
            missing: path to a text file with one missing video id per line.
            keypoint_extractor: object providing
                ``extract_keypoints_from_video(video_name) -> DataFrame``.
            subset: which split to keep ("train", "val", "test", ...).
            keypoints_identifier: optional mapping whose values name the
                keypoints to keep; expands to ``"<name>_x"/"<name>_y"``
                column selectors. ``None`` keeps all keypoints.
            transform: stored for callers; NOTE(review): currently never
                applied in __getitem__ — confirm intended behavior.
        """
        # Ids of videos that are absent from disk (one id per line).
        with open(missing) as f:
            missing_ids = f.read().splitlines()

        with open(json_file) as f:
            data = json.load(f)

        # Drop entries whose video file is missing; pop with a default
        # avoids the membership-test-then-delete double lookup.
        for vid in missing_ids:
            data.pop(vid, None)

        # Keep only the requested split, preserving the JSON's order.
        self.data = OrderedDict(
            (vid, entry) for vid, entry in data.items() if entry["subset"] == subset
        )
        # Cache key order once so __getitem__ is O(1) instead of
        # rebuilding list(self.data.keys()) on every access.
        self._video_ids = list(self.data)

        self.transform = transform
        self.subset = subset
        self.keypoint_extractor = keypoint_extractor

        # Always define the attribute: __getitem__ reads it
        # unconditionally, so leaving it unset when
        # keypoints_identifier is None raised AttributeError.
        self.keypoints_to_keep = None
        if keypoints_identifier:
            self.keypoints_to_keep = [
                f"{name}_{axis}"
                for name in keypoints_identifier.values()
                for axis in ("x", "y")
            ]

    def __len__(self):
        """Number of videos in the selected subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(keypoints_tensor, label)`` for the index-th video."""
        video_id = self._video_ids[index]
        video_name = f"{video_id}.mp4"

        # DataFrame with interleaved x/y columns, one row per frame.
        keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name)

        # Optionally restrict to the configured keypoint columns.
        if self.keypoints_to_keep:
            keypoints_df = keypoints_df[self.keypoints_to_keep]

        # Reshape (frames, 2*K) interleaved x/y columns into (frames, K, 2).
        keypoints = np.empty(shape=(keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2))
        for col in range(0, keypoints_df.shape[1], 2):
            keypoints[:, col // 2, 0] = keypoints_df.iloc[:, col]
            keypoints[:, col // 2, 1] = keypoints_df.iloc[:, col + 1]

        label = self.data[video_id]["action"][0]

        data = torch.from_numpy(keypoints)
        return data, label