Added embeddings

This commit is contained in:
2023-04-14 14:25:33 +02:00
parent ed0e0f198b
commit 8771fe0721
4 changed files with 4514 additions and 39 deletions

Binary file not shown.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,6 @@
from collections import Counter
import cv2
import mediapipe as mp
import numpy as np
@@ -192,7 +194,7 @@ def normalize_hand(keypoints):
# load training embedding csv
df = pd.read_csv('data/fingerspelling/embeddings.csv')
df = pd.read_csv('embeddings/basic-signs/embeddings.csv')
def minkowski_distance_p(x, y, p=2):
x = np.asarray(x)
@@ -248,7 +250,7 @@ def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
return result
CHECKPOINT_PATH = "out_checkpoints/checkpoint_embed_1105.pth"
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model = SPOTER_EMBEDDINGS(
@@ -281,51 +283,60 @@ def make_prediction(keypoints):
# calculate distance matrix
dist_matrix = distance_matrix(new_embeddings, embeddings, p=2, threshold=1000000)
# find closest match
closest_match = np.argmin(dist_matrix[0])
# get the 5 closest matches and select the class that is most common and use the average distance as the score
# get the 5 closest matches
indeces = np.argsort(dist_matrix)[0][:5]
# get the labels
labels = df["label_name"].iloc[indeces].tolist()
c = Counter(labels).most_common()[0][0]
# if dist_matrix[0][closest_match] < 2:
return df.iloc[closest_match]["label_name"], dist_matrix[0][closest_match]
# filter indeces to only include the most common label
indeces = [i for i in indeces if df["label_name"].iloc[i] == c]
# get the average distance
score = np.mean(dist_matrix[0][indeces])
return c, score
# open webcam stream
cap = cv2.VideoCapture(0)
# cap = cv2.VideoCapture(0)
while cap.isOpened():
# read frame
ret, frame = cap.read()
pose = extract_keypoints(frame)
if pose is None:
cv2.imshow('MediaPipe Hands', frame)
continue
buffer.append(pose)
if len(buffer) > 15:
buffer.pop(0)
if len(buffer) == 15:
label, score = make_prediction(buffer)
# draw label
cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
# Show the frame
cv2.imshow('MediaPipe Hands', frame)
# Wait for key press to exit
if cv2.waitKey(5) & 0xFF == 27:
break
# open video A.mp4
# cap = cv2.VideoCapture('Z.mp4')
# while cap.isOpened():
# # read frame
# ret, frame = cap.read()
# if frame is None:
# break
# pose = extract_keypoints(frame)
# if pose is None:
# cv2.imshow('MediaPipe Hands', frame)
# continue
# buffer.append(pose)
# if len(buffer) > 15:
# buffer.pop(0)
# if len(buffer) == 15:
# label, score = make_prediction(buffer)
# # draw label
# cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
# cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
# # Show the frame
# cv2.imshow('MediaPipe Hands', frame)
# # Wait for key press to exit
# if cv2.waitKey(5) & 0xFF == 27:
# break
# open video A.mp4
cap = cv2.VideoCapture('E.mp4')
while cap.isOpened():
# read frame
ret, frame = cap.read()
if frame is None:
break
pose = extract_keypoints(frame)
buffer.append(pose)
# make_prediction(buffer)
label, score = make_prediction(buffer)
print(label, score)