Added embeddings
This commit is contained in:
BIN
checkpoints/checkpoint_embed_1105.pth
Normal file
BIN
checkpoints/checkpoint_embed_1105.pth
Normal file
Binary file not shown.
1013
embeddings/basic-signs/embeddings.csv
Normal file
1013
embeddings/basic-signs/embeddings.csv
Normal file
File diff suppressed because it is too large
Load Diff
3451
embeddings/fingerspelling/embeddings.csv
Normal file
3451
embeddings/fingerspelling/embeddings.csv
Normal file
File diff suppressed because it is too large
Load Diff
89
webcam.py
89
webcam.py
@@ -1,4 +1,6 @@
|
||||
|
||||
from collections import Counter
|
||||
|
||||
import cv2
|
||||
import mediapipe as mp
|
||||
import numpy as np
|
||||
@@ -192,7 +194,7 @@ def normalize_hand(keypoints):
|
||||
|
||||
|
||||
# load training embedding csv
|
||||
df = pd.read_csv('data/fingerspelling/embeddings.csv')
|
||||
df = pd.read_csv('embeddings/basic-signs/embeddings.csv')
|
||||
|
||||
def minkowski_distance_p(x, y, p=2):
|
||||
x = np.asarray(x)
|
||||
@@ -248,7 +250,7 @@ def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
|
||||
return result
|
||||
|
||||
|
||||
CHECKPOINT_PATH = "out_checkpoints/checkpoint_embed_1105.pth"
|
||||
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
|
||||
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
|
||||
|
||||
model = SPOTER_EMBEDDINGS(
|
||||
@@ -281,51 +283,60 @@ def make_prediction(keypoints):
|
||||
# calculate distance matrix
|
||||
dist_matrix = distance_matrix(new_embeddings, embeddings, p=2, threshold=1000000)
|
||||
|
||||
# find closest match
|
||||
closest_match = np.argmin(dist_matrix[0])
|
||||
# get the 5 closest matches and select the class that is most common and use the average distance as the score
|
||||
# get the 5 closest matches
|
||||
indeces = np.argsort(dist_matrix)[0][:5]
|
||||
# get the labels
|
||||
labels = df["label_name"].iloc[indeces].tolist()
|
||||
c = Counter(labels).most_common()[0][0]
|
||||
|
||||
# if dist_matrix[0][closest_match] < 2:
|
||||
return df.iloc[closest_match]["label_name"], dist_matrix[0][closest_match]
|
||||
# filter indeces to only include the most common label
|
||||
indeces = [i for i in indeces if df["label_name"].iloc[i] == c]
|
||||
# get the average distance
|
||||
score = np.mean(dist_matrix[0][indeces])
|
||||
|
||||
return c, score
|
||||
|
||||
# open webcam stream
|
||||
cap = cv2.VideoCapture(0)
|
||||
# cap = cv2.VideoCapture(0)
|
||||
|
||||
while cap.isOpened():
|
||||
# read frame
|
||||
ret, frame = cap.read()
|
||||
pose = extract_keypoints(frame)
|
||||
|
||||
if pose is None:
|
||||
cv2.imshow('MediaPipe Hands', frame)
|
||||
continue
|
||||
|
||||
buffer.append(pose)
|
||||
if len(buffer) > 15:
|
||||
buffer.pop(0)
|
||||
|
||||
if len(buffer) == 15:
|
||||
label, score = make_prediction(buffer)
|
||||
|
||||
# draw label
|
||||
cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||
cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||
|
||||
# Show the frame
|
||||
cv2.imshow('MediaPipe Hands', frame)
|
||||
|
||||
# Wait for key press to exit
|
||||
if cv2.waitKey(5) & 0xFF == 27:
|
||||
break
|
||||
|
||||
# open video A.mp4
|
||||
# cap = cv2.VideoCapture('Z.mp4')
|
||||
# while cap.isOpened():
|
||||
# # read frame
|
||||
# ret, frame = cap.read()
|
||||
# if frame is None:
|
||||
# break
|
||||
# pose = extract_keypoints(frame)
|
||||
|
||||
# if pose is None:
|
||||
# cv2.imshow('MediaPipe Hands', frame)
|
||||
# continue
|
||||
|
||||
# buffer.append(pose)
|
||||
# if len(buffer) > 15:
|
||||
# buffer.pop(0)
|
||||
|
||||
# if len(buffer) == 15:
|
||||
# label, score = make_prediction(buffer)
|
||||
|
||||
# # draw label
|
||||
# cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||
# cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||
|
||||
# # Show the frame
|
||||
# cv2.imshow('MediaPipe Hands', frame)
|
||||
|
||||
# # Wait for key press to exit
|
||||
# if cv2.waitKey(5) & 0xFF == 27:
|
||||
# break
|
||||
|
||||
# open video A.mp4
|
||||
cap = cv2.VideoCapture('E.mp4')
|
||||
while cap.isOpened():
|
||||
# read frame
|
||||
ret, frame = cap.read()
|
||||
if frame is None:
|
||||
break
|
||||
pose = extract_keypoints(frame)
|
||||
|
||||
buffer.append(pose)
|
||||
|
||||
# make_prediction(buffer)
|
||||
label, score = make_prediction(buffer)
|
||||
print(label, score)
|
||||
Reference in New Issue
Block a user