Merge branch 'own_data_dataset' into 'main'
Implement a PyTorch dataset for our own collected data. See merge request wesign/sign-predictor!3
This commit was merged in pull request #3.
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,5 +1,8 @@
|
|||||||
.devcontainer/
|
.devcontainer/
|
||||||
data/
|
data/
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
cache/
|
cache/
|
||||||
__pycache__/
|
cache_wlasl/
|
||||||
|
|
||||||
|
__pycache__/
|
||||||
|
|||||||
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
0
src/datasets/__init__.py
Normal file
0
src/datasets/__init__.py
Normal file
76
src/datasets/finger_spelling_dataset.py
Normal file
76
src/datasets/finger_spelling_dataset.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
|
from src.identifiers import LANDMARKS
|
||||||
|
from src.keypoint_extractor import KeypointExtractor
|
||||||
|
|
||||||
|
|
||||||
|
class FingerSpellingDataset(torch.utils.data.Dataset):
    """PyTorch dataset over self-recorded finger-spelling videos.

    Every file in ``data_folder`` whose name ends with ``.mp4`` is one sample.
    The part of the file name before the first ``!`` is used as its label.
    Labels are mapped to integer ids, and the mapping is written to
    ``label_mapping.txt`` inside ``data_folder`` (one ``<label> <id>`` pair per
    line) every time the dataset is constructed.

    The samples are split 60/40 into a train and a validation part with a
    fixed seed, stratified by label.
    """

    def __init__(self, data_folder: str, keypoint_extractor: "KeypointExtractor", subset: str = "train", keypoints_identifier: dict = None, transform=None):
        """Index the videos in ``data_folder`` and select one subset.

        Args:
            data_folder: Folder that is listed for ``.mp4`` files.
            keypoint_extractor: Object whose ``extract_keypoints_from_video``
                method is called with a file name in ``__getitem__``.
            subset: ``"train"`` or ``"val"``.
            keypoints_identifier: Optional mapping whose values name the
                keypoints to keep; for each value ``v`` the columns ``v_x``
                and ``v_y`` are selected. When omitted, all columns are used.
            transform: Stored on the instance.
                NOTE(review): it is not applied in ``__getitem__`` - confirm
                whether that is intended.

        Raises:
            ValueError: If ``subset`` is not ``"train"`` or ``"val"``, or if
                ``data_folder`` contains no ``.mp4`` files.
        """
        # Fail early instead of leaving self.data / self.labels undefined.
        if subset not in ("train", "val"):
            raise ValueError(f"subset must be 'train' or 'val', got {subset!r}")

        self.data_folder = data_folder
        self.subset = subset
        self.transform = transform
        self.keypoint_extractor = keypoint_extractor

        # os.listdir returns entries in arbitrary order; sorting makes the
        # seeded split below reproducible across machines and file systems.
        files = sorted(f for f in os.listdir(self.data_folder) if f.endswith(".mp4"))
        if not files:
            raise ValueError(f"no .mp4 files found in {self.data_folder!r}")

        # The label is everything before the first "!" in the file name.
        label_names = [f.split("!")[0] for f in files]

        # label_mapping[i] is the label name for integer id i; the inverse
        # index gives each file's id directly.
        self.label_mapping, inverse = np.unique(label_names, return_inverse=True)
        labels = list(inverse.ravel())

        # Save the label mapping, one "<label> <id>" pair per line.
        mapping_path = os.path.join(self.data_folder, "label_mapping.txt")
        with open(mapping_path, "w", encoding="utf-8") as f:
            f.writelines(f"{label} {i}\n" for i, label in enumerate(self.label_mapping))

        # TODO: make split for train and val and test when enough data is available
        x_train, x_val, y_train, y_val = train_test_split(
            files, labels, test_size=0.4, random_state=1, stratify=labels
        )
        if subset == "train":
            self.data, self.labels = x_train, y_train
        else:
            self.data, self.labels = x_val, y_val

        # Always define the attribute so __getitem__ can test it safely.
        if keypoints_identifier:
            self.keypoints_to_keep = [
                f"{i}_{j}" for i in keypoints_identifier.values() for j in ["x", "y"]
            ]
        else:
            self.keypoints_to_keep = None

    def __len__(self):
        """Return the number of videos in the selected subset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(keypoints, label)`` for the video at ``index``.

        ``keypoints`` is a float64 tensor of shape
        ``(rows, columns // 2, 2)`` built from the extractor's data frame,
        where consecutive column pairs become the last axis. ``label`` is the
        integer id of the video's label.

        Raises:
            ValueError: If the (filtered) data frame has an odd number of
                columns and therefore cannot be grouped into pairs.
        """
        video_name = self.data[index]

        # Get the keypoints for the video.
        keypoints_df = self.keypoint_extractor.extract_keypoints_from_video(video_name)

        # Filter the keypoints by the identified subset.
        if self.keypoints_to_keep:
            keypoints_df = keypoints_df[self.keypoints_to_keep]

        n_rows, n_columns = keypoints_df.shape

        # Columns (2k, 2k + 1) become entry k of the middle axis. The copy
        # gives torch a writable buffer that it owns.
        flat = keypoints_df.to_numpy(dtype=np.float64, copy=True)
        keypoints = flat.reshape(n_rows, n_columns // 2, 2)

        return torch.from_numpy(keypoints), self.labels[index]
|
||||||
@@ -1,12 +1,14 @@
|
|||||||
import mediapipe as mp
|
|
||||||
import cv2
|
|
||||||
import time
|
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
import numpy as np
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import mediapipe as mp
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
class KeypointExtractor:
|
class KeypointExtractor:
|
||||||
def __init__(self, video_folder: str, cache_folder: str = "cache"):
|
def __init__(self, video_folder: str, cache_folder: str = "cache"):
|
||||||
self.mp_drawing = mp.solutions.drawing_utils
|
self.mp_drawing = mp.solutions.drawing_utils
|
||||||
@@ -52,7 +54,18 @@ class KeypointExtractor:
|
|||||||
|
|
||||||
keypoints_df = pd.DataFrame(columns=self.columns)
|
keypoints_df = pd.DataFrame(columns=self.columns)
|
||||||
|
|
||||||
|
# subsample the video to roughly 5 frames per second by skipping frames between reads
|
||||||
|
frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
|
||||||
|
frame_skip = frame_rate // 5
|
||||||
|
|
||||||
while cap.isOpened():
|
while cap.isOpened():
|
||||||
|
|
||||||
|
# skip frames
|
||||||
|
for _ in range(frame_skip):
|
||||||
|
success, image = cap.read()
|
||||||
|
if not success:
|
||||||
|
break
|
||||||
|
|
||||||
success, image = cap.read()
|
success, image = cap.read()
|
||||||
if not success:
|
if not success:
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import torch.optim as optim
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
from dataset import WLASLDataset
|
from datasets.wlasl_dataset import WLASLDataset
|
||||||
from identifiers import LANDMARKS
|
from identifiers import LANDMARKS
|
||||||
from keypoint_extractor import KeypointExtractor
|
from keypoint_extractor import KeypointExtractor
|
||||||
from model import SPOTER
|
from model import SPOTER
|
||||||
|
|||||||
Reference in New Issue
Block a user