Merge branches 'dev' and 'dev' of https://gitlab.ilabt.imec.be/wesign/sign-predictor into dev
This commit is contained in:
@@ -4,8 +4,8 @@ import numpy as np
|
||||
import torch
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from identifiers import LANDMARKS
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
from src.identifiers import LANDMARKS
|
||||
from src.keypoint_extractor import KeypointExtractor
|
||||
|
||||
|
||||
class FingerSpellingDataset(torch.utils.data.Dataset):
|
||||
|
||||
@@ -4,8 +4,8 @@ from collections import OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from identifiers import LANDMARKS
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
from src.identifiers import LANDMARKS
|
||||
from src.keypoint_extractor import KeypointExtractor
|
||||
|
||||
|
||||
class WLASLDataset(torch.utils.data.Dataset):
|
||||
|
||||
@@ -151,25 +151,34 @@ class KeypointExtractor:
|
||||
return results
|
||||
|
||||
|
||||
def normalize_hands(self, dataframe: pd.DataFrame) -> pd.DataFrame:
|
||||
def normalize_hands(self, dataframe: pd.DataFrame, norm_algorithm: str="minmax") -> pd.DataFrame:
|
||||
"""normalize_hands this function normalizes the left and right hand keypoints of a dataframe
|
||||
|
||||
:param dataframe: the dataframe to normalize
|
||||
:type dataframe: pd.DataFrame
|
||||
:param norm_algorithm: the normalization algorithm to use, either "minmax" or "bohacek"; any other value returns the dataframe unchanged
|
||||
:type norm_algorithm: str
|
||||
:return: the normalized dataframe
|
||||
:rtype: pd.DataFrame
|
||||
"""
|
||||
|
||||
# normalize left hand
|
||||
dataframe = self.normalize_hand_helper(dataframe, "left_hand")
|
||||
|
||||
# normalize right hand
|
||||
dataframe = self.normalize_hand_helper(dataframe, "right_hand")
|
||||
if norm_algorithm == "minmax":
|
||||
# normalize left hand
|
||||
dataframe = self.normalize_hand_minmax(dataframe, "left_hand")
|
||||
# normalize right hand
|
||||
dataframe = self.normalize_hand_minmax(dataframe, "right_hand")
|
||||
elif norm_algorithm == "bohacek":
|
||||
# normalize left hand
|
||||
dataframe = self.normalize_hand_bohacek(dataframe, "left_hand")
|
||||
# normalize right hand
|
||||
dataframe = self.normalize_hand_bohacek(dataframe, "right_hand")
|
||||
else:
|
||||
return dataframe
|
||||
|
||||
return dataframe
|
||||
|
||||
def normalize_hand_helper(self, dataframe: pd.DataFrame, hand: str) -> pd.DataFrame:
|
||||
"""normalize_hand_helper this function normalizes the hand keypoints of a dataframe
|
||||
def normalize_hand_minmax(self, dataframe: pd.DataFrame, hand: str) -> pd.DataFrame:
|
||||
"""normalize_hand_minmax this function normalizes the hand keypoints of a dataframe with respect to the minimum and maximum coordinates
|
||||
|
||||
:param dataframe: the dataframe to normalize
|
||||
:type dataframe: pd.DataFrame
|
||||
@@ -194,9 +203,66 @@ class KeypointExtractor:
|
||||
# calculate the width and height of the bounding box around the hand keypoints
|
||||
bbox_width, bbox_height = max_x - min_x, max_y - min_y
|
||||
|
||||
# repeat the center coordinates and bounding box dimensions to match the shape of hand_coords (numpy magic)
|
||||
center_x, center_y = center_x.reshape(-1, 1, 1), center_y.reshape(-1, 1, 1)
|
||||
center_coords = np.concatenate((np.tile(center_x, (1, 21, 1)), np.tile(center_y, (1, 21, 1))), axis=2)
|
||||
|
||||
bbox_width, bbox_height = bbox_width.reshape(-1, 1, 1), bbox_height.reshape(-1, 1 ,1)
|
||||
bbox_dims = np.concatenate((np.tile(bbox_width, (1, 21, 1)), np.tile(bbox_height, (1, 21, 1))), axis=2)
|
||||
|
||||
if np.any(bbox_dims == 0):
|
||||
return dataframe
|
||||
# normalize the hand keypoints based on the bounding box around the hand
|
||||
norm_hand_coords = (hand_coords - center_coords) / bbox_dims
|
||||
|
||||
# flatten the normalized hand keypoints array and replace the original hand keypoints with the normalized hand keypoints in the dataframe
|
||||
dataframe.iloc[:, hand_columns] = norm_hand_coords.reshape(-1, 42)
|
||||
|
||||
return dataframe
|
||||
|
||||
def normalize_hand_bohacek(self, dataframe: pd.DataFrame, hand: str) -> pd.DataFrame:
|
||||
"""normalize_hand_bohacek this function normalizes the hand keypoints of a dataframe using the bohacek normalization algorithm
|
||||
|
||||
:param dataframe: the dataframe to normalize
|
||||
:type dataframe: pd.DataFrame
|
||||
:param hand: the hand to normalize
|
||||
:type hand: str
|
||||
:return: the normalized dataframe
|
||||
:rtype: pd.DataFrame
|
||||
"""
|
||||
# get all columns that belong to the hand (left hand column 66 - 107, right hand column 108 - 149)
|
||||
hand_columns = np.array([i for i in range(66 + (42 if hand == "right_hand" else 0), 108 + (42 if hand == "right_hand" else 0))])
|
||||
|
||||
# get the x, y coordinates of the hand keypoints
|
||||
hand_coords = dataframe.iloc[:, hand_columns].values.reshape(-1, 21, 2)
|
||||
|
||||
# get the min and max x, y coordinates of the hand keypoints
|
||||
min_x, min_y = np.min(hand_coords[:, :, 0], axis=1), np.min(hand_coords[:, :, 1], axis=1)
|
||||
max_x, max_y = np.max(hand_coords[:, :, 0], axis=1), np.max(hand_coords[:, :, 1], axis=1)
|
||||
|
||||
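# calculate the padding deltas that make the bounding box square: 10% of the longer side, plus half the side difference on the shorter axis
# NOTE(review): width and height are per-row arrays (min/max taken with axis=1), so the scalar `if width > height` below is ambiguous for more than one row - confirm whether an element-wise selection (e.g. np.where) is intended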
|
||||
width, height = max_x - min_x, max_y - min_y
|
||||
if width > height:
|
||||
delta_x = 0.1 * width
|
||||
delta_y = delta_x + ((width - height) / 2)
|
||||
else:
|
||||
delta_y = 0.1 * height
|
||||
delta_x = delta_y + ((height - width) / 2)
|
||||
|
||||
# Set the starting and ending point of the normalization bounding box
|
||||
starting_x, starting_y = min_x - delta_x, min_y - delta_y
|
||||
ending_x, ending_y = max_x + delta_x, max_y + delta_y
|
||||
|
||||
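# calculate the center of the bounding box and the bounding box dimensions
# NOTE(review): the dimensions are computed as starting - ending, which is negative whenever ending > starting; ending - starting is probably intended - confirm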
|
||||
bbox_center_x, bbox_center_y = (starting_x + ending_x) / 2, (starting_y + ending_y) / 2
|
||||
bbox_width, bbox_height = starting_x - ending_x, starting_y - ending_y
|
||||
|
||||
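# repeat the center coordinates and bounding box dimensions to match the shape of hand_coords
# NOTE(review): center_x / center_y are not assigned anywhere above in this function (only bbox_center_x / bbox_center_y are), and bbox_center_x / bbox_center_y are tiled further down without the reshape(-1, 1, 1) applied to bbox_width / bbox_height - confirm which of the following lines are leftovers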
|
||||
center_coords = np.tile(np.array([center_x, center_y]), (21, 1)).reshape(-1, 21, 2)
|
||||
bbox_dims = np.tile(np.array([bbox_width, bbox_height]), (21, 1)).reshape(-1, 21, 2)
|
||||
center_x, center_y = center_x.reshape(-1, 1, 1), center_y.reshape(-1, 1, 1)
|
||||
center_coords = np.concatenate((np.tile(bbox_center_x, (1, 21, 1)), np.tile(bbox_center_y, (1, 21, 1))), axis=2)
|
||||
|
||||
bbox_width, bbox_height = bbox_width.reshape(-1, 1, 1), bbox_height.reshape(-1, 1 ,1)
|
||||
bbox_dims = np.concatenate((np.tile(bbox_width, (1, 21, 1)), np.tile(bbox_height, (1, 21, 1))), axis=2)
|
||||
|
||||
if np.any(bbox_dims == 0):
|
||||
return dataframe
|
||||
|
||||
19
src/train.py
19
src/train.py
@@ -13,12 +13,12 @@ import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
from augmentations import MirrorKeypoints
|
||||
from datasets.finger_spelling_dataset import FingerSpellingDataset
|
||||
from datasets.wlasl_dataset import WLASLDataset
|
||||
from identifiers import LANDMARKS
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
from model import SPOTER
|
||||
from src.augmentations import MirrorKeypoints
|
||||
from src.datasets.finger_spelling_dataset import FingerSpellingDataset
|
||||
from src.datasets.wlasl_dataset import WLASLDataset
|
||||
from src.identifiers import LANDMARKS
|
||||
from src.keypoint_extractor import KeypointExtractor
|
||||
from src.model import SPOTER
|
||||
|
||||
|
||||
def train():
|
||||
@@ -81,10 +81,7 @@ def train():
|
||||
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
|
||||
pred_correct += 1
|
||||
pred_all += 1
|
||||
|
||||
# if i % 100 == 0:
|
||||
# print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
|
||||
|
||||
|
||||
if scheduler:
|
||||
scheduler.step(running_loss.item() / len(train_loader))
|
||||
|
||||
@@ -107,7 +104,7 @@ def train():
|
||||
|
||||
|
||||
# save checkpoint
|
||||
if val_acc > top_val_acc:
|
||||
if val_acc > top_val_acc and epoch > 55:
|
||||
top_val_acc = val_acc
|
||||
top_train_acc = train_acc
|
||||
checkpoint_index = epoch
|
||||
|
||||
Reference in New Issue
Block a user