import json
from collections import OrderedDict

import numpy as np
import torch

from src.identifiers import LANDMARKS
from src.keypoint_extractor import KeypointExtractor


class WLASLDataset(torch.utils.data.Dataset):
    """Torch dataset over one subset (train/val/test) of the WLASL corpus.

    Each item is a ``(keypoints, label)`` pair where ``keypoints`` is a
    float64 tensor of shape ``(frames, n_keypoints, 2)`` holding the
    ``(x, y)`` coordinates produced by ``keypoint_extractor`` and ``label``
    is the first entry of the video's ``"action"`` field.
    """

    def __init__(self, json_file: str, missing: str,
                 keypoint_extractor: KeypointExtractor,
                 subset: str = "train",
                 keypoints_identifier: dict = None,
                 transform=None):
        """Load metadata, drop missing videos, and keep only `subset`.

        Args:
            json_file: Path to the WLASL metadata JSON (video_id -> record).
            missing: Path to a text file listing IDs of videos absent on
                disk, one per line.
            keypoint_extractor: Object whose
                ``extract_keypoints_from_video(name)`` returns a DataFrame
                of per-frame keypoint columns (``<kp>_x``/``<kp>_y``).
            subset: Which ``"subset"`` value to keep ("train", "val", ...).
            keypoints_identifier: Optional mapping whose values name the
                keypoints to keep; ``None`` keeps all columns.
            transform: Optional callable applied to the keypoint tensor
                in ``__getitem__``.
        """
        # IDs of videos whose files are absent on disk, one per line.
        # (Read into a new name instead of shadowing the `missing` path.)
        with open(missing) as f:
            missing_ids = f.read().splitlines()

        with open(json_file) as f:
            metadata = json.load(f)

        # Drop entries for missing videos before subset filtering.
        for video_id in missing_ids:
            metadata.pop(video_id, None)

        # Keep only the requested subset, preserving JSON insertion order.
        self.data = OrderedDict(
            (k, v) for k, v in metadata.items() if v["subset"] == subset
        )
        # Cache the key order once so __getitem__ is O(1) instead of
        # rebuilding list(self.data.keys()) on every access.
        self._video_ids = list(self.data)

        self.transform = transform
        self.subset = subset
        self.keypoint_extractor = keypoint_extractor

        # BUG FIX: this attribute was previously assigned only when
        # `keypoints_identifier` was truthy, so __getitem__ raised
        # AttributeError whenever the default (None) was used.
        self.keypoints_to_keep = None
        if keypoints_identifier:
            # One "<name>_x"/"<name>_y" column pair per identified keypoint.
            self.keypoints_to_keep = [
                f"{i}_{j}"
                for i in keypoints_identifier.values()
                for j in ("x", "y")
            ]

    def __len__(self):
        """Number of videos in the selected subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(keypoints, label)`` for the index-th video.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        video_id = self._video_ids[index]
        video_name = f"{video_id}.mp4"

        # DataFrame with interleaved ..._x, ..._y columns per keypoint
        # (presumably one row per frame — governed by KeypointExtractor).
        keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(
            video_name
        )

        # Restrict to the identified keypoint columns, if configured.
        if self.keypoints_to_keep:
            keypoints_df = keypoints_df[self.keypoints_to_keep]

        # The interleaved (x0, y0, x1, y1, ...) columns are row-major, so a
        # single reshape replaces the original per-column copy loop.
        coords = keypoints_df.to_numpy(dtype=np.float64).reshape(
            keypoints_df.shape[0], keypoints_df.shape[1] // 2, 2
        )

        label = self.data[video_id]["action"][0]

        data = torch.from_numpy(coords)
        # BUG FIX: `transform` was stored but never applied.
        if self.transform is not None:
            data = self.transform(data)
        return data, label