Basic version to compare embeddings and the last level opf prediction
This commit is contained in:
86
predictions/plotting.py
Normal file
86
predictions/plotting.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import json
|
||||
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
def load_results():
|
||||
with open("predictions/test_results/knn.json", 'r') as f:
|
||||
results = json.load(f)
|
||||
return results
|
||||
|
||||
def plot_all():
|
||||
results = load_results()
|
||||
print(f"average elapsed time to detect a sign: {get_general_elapsed_time(results)}")
|
||||
plot_general_accuracy(results)
|
||||
for label in results.keys():
|
||||
plot_accuracy_per_label(results, label)
|
||||
|
||||
|
||||
def general_accuracy(results):
|
||||
label_accuracy = get_label_accuracy(results)
|
||||
accuracy = []
|
||||
amount = []
|
||||
response = []
|
||||
for label in label_accuracy.keys():
|
||||
for index, value in enumerate(label_accuracy[label]):
|
||||
if index >= len(accuracy):
|
||||
accuracy.append(0)
|
||||
amount.append(0)
|
||||
accuracy[index] += label_accuracy[label][index]
|
||||
amount[index] += 1
|
||||
for a, b in zip(accuracy, amount):
|
||||
if b < 5:
|
||||
break
|
||||
response.append(a / b)
|
||||
return response
|
||||
def plot_general_accuracy(results):
|
||||
accuracy = general_accuracy(results)
|
||||
plt.plot(accuracy)
|
||||
plt.title = "General accuracy"
|
||||
plt.ylabel('accuracy')
|
||||
plt.xlabel('buffer')
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_accuracy_per_label(results, label):
|
||||
accuracy = get_label_accuracy(results)
|
||||
plt.plot(accuracy[label], label=label)
|
||||
plt.titel = f"Accuracy per label {label}"
|
||||
plt.ylabel('accuracy')
|
||||
plt.xlabel('prediction')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
def get_label_accuracy(results):
|
||||
accuracy = {}
|
||||
amount = {}
|
||||
response = {}
|
||||
for label, predictions in results.items():
|
||||
if label not in accuracy:
|
||||
accuracy[label] = []
|
||||
amount[label] = []
|
||||
for prediction in predictions:
|
||||
for index, value in enumerate(prediction["predictions"]):
|
||||
if index >= len(accuracy[label]):
|
||||
accuracy[label].append(0)
|
||||
amount[label].append(0)
|
||||
accuracy[label][index] += 1 if value["correct"] else 0
|
||||
amount[label][index] += 1
|
||||
for label in accuracy:
|
||||
response[label] = []
|
||||
for index, value in enumerate(accuracy[label]):
|
||||
if amount[label][index] < 2:
|
||||
break
|
||||
response[label].append(accuracy[label][index] / amount[label][index])
|
||||
return response
|
||||
|
||||
def get_general_elapsed_time(results):
|
||||
label_time = get_label_elapsed_time(results)
|
||||
return sum([label_time[label] for label in results]) / len(results)
|
||||
|
||||
def get_label_elapsed_time(results):
|
||||
return {label: sum([result["elapsed_time"] for result in results[label]]) / len(results[label]) for label in results}
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plot_all()
|
||||
Reference in New Issue
Block a user