implementing KeypointExtractor

This commit is contained in:
2023-02-16 17:56:37 +00:00
parent 970dd19702
commit ad7b160c92
4 changed files with 147 additions and 0 deletions

28
dataset.py Normal file
View File

@@ -0,0 +1,28 @@
import torch
import pandas as pd
from PIL import Image
import json
class WLASLDataset(torch.utils.data.Dataset):
def __init__(self, csv_file: str, video_dir: str, subset:str="train", keypoints_file: str = "keypoints.csv", transform=None):
self.df = pd.read_csv(csv_file)
# filter wlasl data by subset
self.df = self.df[self.df["subset"] == subset]
self.video_dir = video_dir
self.transform = transform
self.subset = subset
self.keypoints_file = keypoints_file
def __len__(self):
return len(self.df)
def __getitem__(self, index):
video_id = self.df.iloc[index]["video_id"]
# check if keypoints file exists
if not os.path.exists(self.keypoints_file):
# create empty dataframe
keypoints_df = pd.DataFrame(columns=["video_id", "keypoints"])
# check if keypoints are available else extract from video