Added embeddings

This commit is contained in:
2023-04-14 14:25:33 +02:00
parent ed0e0f198b
commit 8771fe0721
4 changed files with 4514 additions and 39 deletions

View File

@@ -1,4 +1,6 @@
from collections import Counter
import cv2
import mediapipe as mp
import numpy as np
@@ -192,7 +194,7 @@ def normalize_hand(keypoints):
# load training embedding csv
df = pd.read_csv('data/fingerspelling/embeddings.csv')
df = pd.read_csv('embeddings/basic-signs/embeddings.csv')
def minkowski_distance_p(x, y, p=2):
x = np.asarray(x)
@@ -248,7 +250,7 @@ def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
return result
CHECKPOINT_PATH = "out_checkpoints/checkpoint_embed_1105.pth"
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model = SPOTER_EMBEDDINGS(
@@ -281,51 +283,60 @@ def make_prediction(keypoints):
# calculate distance matrix
dist_matrix = distance_matrix(new_embeddings, embeddings, p=2, threshold=1000000)
# find closest match
closest_match = np.argmin(dist_matrix[0])
# get the 5 closest matches and select the class that is most common and use the average distance as the score
# get the 5 closest matches
indeces = np.argsort(dist_matrix)[0][:5]
# get the labels
labels = df["label_name"].iloc[indeces].tolist()
c = Counter(labels).most_common()[0][0]
# if dist_matrix[0][closest_match] < 2:
return df.iloc[closest_match]["label_name"], dist_matrix[0][closest_match]
# filter indeces to only include the most common label
indeces = [i for i in indeces if df["label_name"].iloc[i] == c]
# get the average distance
score = np.mean(dist_matrix[0][indeces])
return c, score
# Open the webcam and run live sign-recognition: keep a sliding window of the
# 15 most recent hand poses and overlay the predicted label + distance score.
cap = cv2.VideoCapture(0)
while cap.isOpened():
    # read frame; a None frame means the camera stopped delivering
    ret, frame = cap.read()
    if frame is None:
        break
    pose = extract_keypoints(frame)
    if pose is None:
        # No hand detected: still show the raw frame.  waitKey is required
        # here too — without it the HighGUI window never refreshes.
        cv2.imshow('MediaPipe Hands', frame)
        if cv2.waitKey(5) & 0xFF == 27:
            break
        continue
    # keep only the 15 most recent poses (the model's input window)
    buffer.append(pose)
    if len(buffer) > 15:
        buffer.pop(0)
    if len(buffer) == 15:
        label, score = make_prediction(buffer)
        # draw predicted label and its distance score
        cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
        cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
    # Show the frame
    cv2.imshow('MediaPipe Hands', frame)
    # Wait for key press (ESC) to exit
    if cv2.waitKey(5) & 0xFF == 27:
        break
# release the camera handle when the loop ends
cap.release()
# open video A.mp4
# cap = cv2.VideoCapture('Z.mp4')
# while cap.isOpened():
# # read frame
# ret, frame = cap.read()
# if frame is None:
# break
# pose = extract_keypoints(frame)
# if pose is None:
# cv2.imshow('MediaPipe Hands', frame)
# continue
# buffer.append(pose)
# if len(buffer) > 15:
# buffer.pop(0)
# if len(buffer) == 15:
# label, score = make_prediction(buffer)
# # draw label
# cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
# cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
# # Show the frame
# cv2.imshow('MediaPipe Hands', frame)
# # Wait for key press to exit
# if cv2.waitKey(5) & 0xFF == 27:
# break
# Open video E.mp4 and print a prediction for each full 15-frame pose window.
cap = cv2.VideoCapture('E.mp4')
while cap.isOpened():
    # read frame; None signals end of stream
    ret, frame = cap.read()
    if frame is None:
        break
    pose = extract_keypoints(frame)
    if pose is None:
        # skip frames with no detected hand — appending None to the buffer
        # would hand invalid keypoints to make_prediction
        continue
    # keep only the 15 most recent poses, matching the webcam loop's window
    buffer.append(pose)
    if len(buffer) > 15:
        buffer.pop(0)
    if len(buffer) == 15:
        label, score = make_prediction(buffer)
        print(label, score)
# release the video handle when the loop ends
cap.release()