basic svm
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
|
||||||
|
# TODO scaling van distance tov intra distances?
|
||||||
|
# TODO efficientere manier om k=1 te doen
|
||||||
|
|
||||||
|
|
||||||
def minkowski_distance_p(x, y, p=2):
|
def minkowski_distance_p(x, y, p=2):
|
||||||
x = np.asarray(x)
|
x = np.asarray(x)
|
||||||
|
|||||||
39
predictions/svm_model.py
Normal file
39
predictions/svm_model.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from sklearn import svm
|
||||||
|
|
||||||
|
class SVM:
|
||||||
|
def __init__(self, type="ovo"):
|
||||||
|
self.label_name_to_label = None
|
||||||
|
self.clf = None
|
||||||
|
self.embeddings_list = None
|
||||||
|
self.labels = None
|
||||||
|
self.type = type
|
||||||
|
|
||||||
|
def set_embeddings(self, embeddings):
|
||||||
|
# convert embedding from string to list of floats
|
||||||
|
embeddings["embeddings"] = embeddings["embeddings2"].apply(lambda x: [float(i) for i in x[1:-1].split(", ")])
|
||||||
|
# drop embeddings2
|
||||||
|
df = embeddings.drop(columns=['embeddings2'])
|
||||||
|
# to list
|
||||||
|
self.embeddings_list = df["embeddings"].tolist()
|
||||||
|
self.labels = df["labels"].tolist()
|
||||||
|
self.label_name_to_label = df[["label_name", "labels"]]
|
||||||
|
self.label_name_to_label.columns = ["label_name", "label"]
|
||||||
|
self.label_name_to_label = self.label_name_to_label.drop_duplicates()
|
||||||
|
print(self.label_name_to_label)
|
||||||
|
|
||||||
|
self.train()
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
self.clf = svm.SVC(decision_function_shape=self.type, probability=True)
|
||||||
|
self.clf.fit(self.embeddings_list, self.labels)
|
||||||
|
|
||||||
|
def predict(self, key_points_embeddings):
|
||||||
|
label = self.clf.predict(key_points_embeddings)
|
||||||
|
score = self.clf.predict_log_proba(key_points_embeddings)
|
||||||
|
# TODO fix dictionary
|
||||||
|
label = label.item()
|
||||||
|
print("test")
|
||||||
|
print(self.label_name_to_label.loc[self.label_name_to_label["label"] == label]["label_name"].iloc[0])
|
||||||
|
print("test2")
|
||||||
|
print(score)
|
||||||
|
return self.label_name_to_label.loc[self.label_name_to_label["label"] == label]["label_name"].iloc[0], score[0][label]
|
||||||
16
webcam.py
16
webcam.py
@@ -2,15 +2,23 @@ import cv2
|
|||||||
|
|
||||||
from predictions.k_nearest import KNearestNeighbours
|
from predictions.k_nearest import KNearestNeighbours
|
||||||
from predictions.predictor import Predictor
|
from predictions.predictor import Predictor
|
||||||
|
from predictions.svm_model import SVM
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
buffer = []
|
buffer = []
|
||||||
|
|
||||||
# open webcam stream
|
# open webcam stream
|
||||||
cap = cv2.VideoCapture(0)
|
cap = cv2.VideoCapture(0)
|
||||||
|
|
||||||
k = 3
|
type_predictor = "svm"
|
||||||
predictor_type = KNearestNeighbours(k)
|
if type_predictor == "knn":
|
||||||
|
k = 10
|
||||||
|
predictor_type = KNearestNeighbours(k)
|
||||||
|
elif type_predictor == "svm":
|
||||||
|
predictor_type = SVM()
|
||||||
|
else:
|
||||||
|
predictor_type = KNearestNeighbours(1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# embeddings_path = 'embeddings/basic-signs/embeddings.csv'
|
# embeddings_path = 'embeddings/basic-signs/embeddings.csv'
|
||||||
embeddings_path = 'embeddings/fingerspelling/embeddings.csv'
|
embeddings_path = 'embeddings/fingerspelling/embeddings.csv'
|
||||||
@@ -39,7 +47,7 @@ if __name__ == '__main__':
|
|||||||
label, score = predictor.make_prediction(buffer)
|
label, score = predictor.make_prediction(buffer)
|
||||||
|
|
||||||
# draw label
|
# draw label
|
||||||
cv2.putText(frame, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
cv2.putText(frame, str(label), (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||||
cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
cv2.putText(frame, str(score), (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA)
|
||||||
|
|
||||||
# Show the frame
|
# Show the frame
|
||||||
|
|||||||
Reference in New Issue
Block a user