Fixed some processing

This commit is contained in:
2023-04-14 14:35:53 +02:00
parent fd4faa45b2
commit 2cbf11eb00

View File

@@ -260,24 +260,20 @@ model = SPOTER_EMBEDDINGS(
).to(device)
model.load_state_dict(checkpoint["state_dict"])
embeddings = df.drop(columns=['labels', 'label_name', 'embeddings'])
# convert embedding from string to list of floats
embeddings["embeddings"] = embeddings["embeddings2"].apply(lambda x: [float(i) for i in x[1:-1].split(", ")])
# drop embeddings2
embeddings = embeddings.drop(columns=['embeddings2'])
# to list
embeddings = embeddings["embeddings"].tolist()
def make_prediction(keypoints):
embeddings = df.drop(columns=['labels', 'label_name', 'embeddings'])
# convert embedding from string to list of floats
embeddings["embeddings"] = embeddings["embeddings2"].apply(lambda x: [float(i) for i in x[1:-1].split(", ")])
# drop embeddings2
embeddings = embeddings.drop(columns=['embeddings2'])
# to list
embeddings = embeddings["embeddings"].tolist()
# run model on frame
model.eval()
with torch.no_grad():
keypoints = torch.from_numpy(np.array([keypoints])).float().to(device)
with open('inputs.txt', 'w') as f:
for j in range(keypoints.shape[1]):
f.write(str(keypoints[0, j, :].cpu().detach().numpy()) + ' ')
new_embeddings = model(keypoints).cpu().numpy().tolist()[0]
# calculate distance matrix