Added embeddings
This commit is contained in:
BIN
checkpoints/checkpoint_embed_1105.pth
Normal file
BIN
checkpoints/checkpoint_embed_1105.pth
Normal file
Binary file not shown.
1013
embeddings/basic-signs/embeddings.csv
Normal file
1013
embeddings/basic-signs/embeddings.csv
Normal file
File diff suppressed because it is too large
Load Diff
3451
embeddings/fingerspelling/embeddings.csv
Normal file
3451
embeddings/fingerspelling/embeddings.csv
Normal file
File diff suppressed because it is too large
Load Diff
91
webcam.py
91
webcam.py
@@ -1,4 +1,6 @@
|
|||||||
|
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import mediapipe as mp
|
import mediapipe as mp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -192,7 +194,7 @@ def normalize_hand(keypoints):
|
|||||||
|
|
||||||
|
|
||||||
# load training embedding csv
|
# load training embedding csv
|
||||||
df = pd.read_csv('data/fingerspelling/embeddings.csv')
|
df = pd.read_csv('embeddings/basic-signs/embeddings.csv')
|
||||||
|
|
||||||
def minkowski_distance_p(x, y, p=2):
|
def minkowski_distance_p(x, y, p=2):
|
||||||
x = np.asarray(x)
|
x = np.asarray(x)
|
||||||
@@ -248,7 +250,7 @@ def distance_matrix(keypoints, embeddings, p=2, threshold=1000000):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
CHECKPOINT_PATH = "out_checkpoints/checkpoint_embed_1105.pth"
|
CHECKPOINT_PATH = "checkpoints/checkpoint_embed_1105.pth"
|
||||||
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
|
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
|
||||||
|
|
||||||
model = SPOTER_EMBEDDINGS(
|
model = SPOTER_EMBEDDINGS(
|
||||||
@@ -281,51 +283,60 @@ def make_prediction(keypoints):
|
|||||||
# calculate distance matrix
|
# calculate distance matrix
|
||||||
dist_matrix = distance_matrix(new_embeddings, embeddings, p=2, threshold=1000000)
|
dist_matrix = distance_matrix(new_embeddings, embeddings, p=2, threshold=1000000)
|
||||||
|
|
||||||
# find closest match
|
# get the 5 closest matches and select the class that is most common and use the average distance as the score
|
||||||
closest_match = np.argmin(dist_matrix[0])
|
# get the 5 closest matches
|
||||||
|
indeces = np.argsort(dist_matrix)[0][:5]
|
||||||
|
# get the labels
|
||||||
|
labels = df["label_name"].iloc[indeces].tolist()
|
||||||
|
c = Counter(labels).most_common()[0][0]
|
||||||
|
|
||||||
# if dist_matrix[0][closest_match] < 2:
|
# filter indeces to only include the most common label
|
||||||
return df.iloc[closest_match]["label_name"], dist_matrix[0][closest_match]
|
indeces = [i for i in indeces if df["label_name"].iloc[i] == c]
|
||||||
|
# get the average distance
|
||||||
|
score = np.mean(dist_matrix[0][indeces])
|
||||||
|
|
||||||
|
return c, score
|
||||||
|
|
||||||
# open webcam stream
|
# open webcam stream
|
||||||
cap = cv2.VideoCapture(0)
|
# cap = cv2.VideoCapture(0)
|
||||||
|
|
||||||
while cap.isOpened():
|
|
||||||
# read frame
|
|
||||||
ret, frame = cap.read()
|
|
||||||
pose = extract_keypoints(frame)
|
|
||||||
|
|
||||||
if pose is None:
|
|
||||||
cv2.imshow('MediaPipe Hands', frame)
|
|
||||||
continue
|
|
||||||
|
|
||||||
buffer.append(pose)
|
|
||||||
if len(buffer) > 15:
|
|
||||||
buffer.pop(0)
|
|
||||||
|
|
||||||
if len(buffer) == 15:
|
|
||||||
label, score = make_prediction(buffer)
|
|
||||||
|
|
||||||
# draw label
|
|
||||||
cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
|
||||||
cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
|
||||||
|
|
||||||
# Show the frame
|
|
||||||
cv2.imshow('MediaPipe Hands', frame)
|
|
||||||
|
|
||||||
# Wait for key press to exit
|
|
||||||
if cv2.waitKey(5) & 0xFF == 27:
|
|
||||||
break
|
|
||||||
|
|
||||||
# open video A.mp4
|
|
||||||
# cap = cv2.VideoCapture('Z.mp4')
|
|
||||||
# while cap.isOpened():
|
# while cap.isOpened():
|
||||||
# # read frame
|
# # read frame
|
||||||
# ret, frame = cap.read()
|
# ret, frame = cap.read()
|
||||||
# if frame is None:
|
|
||||||
# break
|
|
||||||
# pose = extract_keypoints(frame)
|
# pose = extract_keypoints(frame)
|
||||||
|
|
||||||
# buffer.append(pose)
|
# if pose is None:
|
||||||
|
# cv2.imshow('MediaPipe Hands', frame)
|
||||||
|
# continue
|
||||||
|
|
||||||
# make_prediction(buffer)
|
# buffer.append(pose)
|
||||||
|
# if len(buffer) > 15:
|
||||||
|
# buffer.pop(0)
|
||||||
|
|
||||||
|
# if len(buffer) == 15:
|
||||||
|
# label, score = make_prediction(buffer)
|
||||||
|
|
||||||
|
# # draw label
|
||||||
|
# cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||||
|
# cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||||
|
|
||||||
|
# # Show the frame
|
||||||
|
# cv2.imshow('MediaPipe Hands', frame)
|
||||||
|
|
||||||
|
# # Wait for key press to exit
|
||||||
|
# if cv2.waitKey(5) & 0xFF == 27:
|
||||||
|
# break
|
||||||
|
|
||||||
|
# open video A.mp4
|
||||||
|
cap = cv2.VideoCapture('E.mp4')
|
||||||
|
while cap.isOpened():
|
||||||
|
# read frame
|
||||||
|
ret, frame = cap.read()
|
||||||
|
if frame is None:
|
||||||
|
break
|
||||||
|
pose = extract_keypoints(frame)
|
||||||
|
|
||||||
|
buffer.append(pose)
|
||||||
|
|
||||||
|
label, score = make_prediction(buffer)
|
||||||
|
print(label, score)
|
||||||
Reference in New Issue
Block a user