Created guide and script to export embeddings

This commit is contained in:
2023-04-14 14:40:05 +00:00
parent 49ced1983d
commit 1f24df1b8f
6 changed files with 255 additions and 150 deletions

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "c20f7fd5",
"metadata": {},
"outputs": [],
@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"id": "ada032d0",
"metadata": {},
"outputs": [],
@@ -22,13 +22,12 @@
"import os\n",
"import os.path as op\n",
"import pandas as pd\n",
"import json\n",
"import base64"
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"id": "05682e73",
"metadata": {},
"outputs": [],
@@ -38,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"id": "fede7684",
"metadata": {},
"outputs": [],
@@ -48,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "ce531994",
"metadata": {},
"outputs": [],
@@ -64,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"id": "f4a2d672",
"metadata": {},
"outputs": [],
@@ -87,17 +86,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"id": "1d9db764",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fb050be7710>"
"<torch._C.Generator at 0x7f010919d710>"
]
},
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -119,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"id": "71224139",
"metadata": {},
"outputs": [],
@@ -133,7 +132,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 13,
"id": "013d3774",
"metadata": {},
"outputs": [
@@ -143,7 +142,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 9,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -169,27 +168,28 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 24,
"id": "ba6b58f0",
"metadata": {},
"outputs": [],
"source": [
"SL_DATASET = 'wlasl' # or 'lsa'\n",
"if SL_DATASET == 'wlasl':\n",
"SL_DATASET = 'basic-signs' # or 'wlasl'\n",
"\n",
"if SL_DATASET == 'fingerspelling':\n",
" dataset_name = \"fingerspelling\"\n",
" num_classes = 100\n",
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
"else:\n",
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
" num_classes = 64\n",
" split_dataset_path = \"LSA64_{}.csv\"\n",
" \n",
"elif SL_DATASET == 'wlasl':\n",
" dataset_name = \"wlasl\"\n",
" split_dataset_path = \"WLASL100_{}.csv\"\n",
"elif SL_DATASET == 'basic-signs':\n",
" dataset_name = \"basic-signs\"\n",
" split_dataset_path = \"basic-signs_{}.csv\"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 25,
"id": "5643a72c",
"metadata": {},
"outputs": [],
@@ -209,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 16,
"id": "04a62088",
"metadata": {},
"outputs": [],
@@ -222,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 17,
"id": "79c837c1",
"metadata": {},
"outputs": [],
@@ -253,7 +253,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 18,
"id": "8b5bda73",
"metadata": {},
"outputs": [],
@@ -280,17 +280,17 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 19,
"id": "0efa0871",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(560, 560)"
"(164, 164)"
]
},
"execution_count": 88,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -301,7 +301,21 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 21,
"id": "ab83c6e2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"for split in splits:\n",
" df = dfs[split]\n",
" df['embeddings'] = embeddings_split[split]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0b9fb9c2",
"metadata": {},
"outputs": [
@@ -309,54 +323,42 @@
"name": "stdout",
"output_type": "stream",
"text": [
"0 [0.4734516, -0.58630264, 0.18397862, -0.165259...\n",
"1 [1.6672437, -2.3754091, -0.77506787, -0.666019...\n",
"2 [1.7801772, -0.0077665895, 0.22098881, 0.09736...\n",
"3 [-0.6503094, 0.14683367, 0.1253598, 0.5183654,...\n",
"4 [1.2275296, -0.4874984, 0.56826925, -0.9628880...\n",
"0 [1.7327625, -3.015248, -1.4775522, -0.7505071,...\n",
"1 [2.0936582, -0.596195, -0.7918601, -0.15896143...\n",
"2 [-1.4007742, -0.9608915, 1.3294879, -0.5185398...\n",
"3 [1.3280737, -3.299126, -1.0110444, -1.2528414,...\n",
"4 [-0.071124956, -0.79259753, 0.7182858, 0.38130...\n",
" ... \n",
"555 [-0.4408903, -0.9623146, 0.21583065, -0.381131...\n",
"556 [1.7910445, -3.5434258, -1.332628, -0.95276725...\n",
"557 [2.3283613, 0.11504881, -0.4955331, -0.4563401...\n",
"558 [-1.0491562, -1.1793315, 0.3248821, 0.16679825...\n",
"559 [1.447621, -1.2482919, 0.17936605, -1.4752473,...\n",
"Name: embeddings, Length: 560, dtype: object\n",
"0 B\n",
"1 D\n",
"2 X\n",
"3 O\n",
"4 W\n",
" ..\n",
"555 F\n",
"556 X\n",
"557 Z\n",
"558 Y\n",
"559 W\n",
"Name: label_name, Length: 560, dtype: object\n",
"159 [-1.5968355, 1.9617733, 0.28859574, 1.256657, ...\n",
"160 [0.44801116, -1.8377966, 1.1004394, -1.195648,...\n",
"161 [2.0584257, 1.6986116, 0.5129896, 0.27279535, ...\n",
"162 [1.6695516, -2.967027, -1.5715427, -0.77170163...\n",
"163 [1.4977738, -2.6278958, -1.6123883, -0.8420623...\n",
"Name: embeddings, Length: 164, dtype: object\n",
"0 TOT-ZIENS\n",
"1 GOED\n",
"2 GOEDENACHT\n",
"3 NEE\n",
"4 SLECHT\n",
" ... \n",
"159 SORRY\n",
"160 GOEDEMORGEN\n",
"161 LINKS\n",
"162 TOT-ZIENS\n",
"163 GOED\n",
"Name: label_name, Length: 164, dtype: object\n",
"0 0\n",
"1 1\n",
"2 2\n",
"3 3\n",
"4 5\n",
"4 4\n",
" ..\n",
"555 24\n",
"556 2\n",
"557 14\n",
"558 8\n",
"559 5\n",
"Name: labels, Length: 560, dtype: int64\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_969762/1944871806.py:9: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n"
"159 7\n",
"160 5\n",
"161 13\n",
"162 0\n",
"163 1\n",
"Name: labels, Length: 164, dtype: int64\n"
]
}
],
@@ -372,21 +374,7 @@
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
"\n",
"# save the dfs['train']\n",
"dfs['train'].to_csv('../data/fingerspelling/embeddings.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "ab83c6e2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"for split in splits:\n",
" df = dfs[split]\n",
" df['embeddings'] = embeddings_split[split]"
"dfs['train'].to_csv(f'../data/{dataset_name}/embeddings.csv', index=False)"
]
},
{
@@ -400,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 23,
"id": "7399b8ae",
"metadata": {},
"outputs": [
@@ -409,16 +397,16 @@
"output_type": "stream",
"text": [
"Using centroids only\n",
"Top-1 accuracy: 77.06 %\n",
"Top-5 embeddings class match: 100.00 % (Picks any class in the 5 closest embeddings)\n",
"Top-1 accuracy: 80.00 %\n",
"Top-5 embeddings class match: 93.33 % (Picks any class in the 5 closest embeddings)\n",
"\n",
"################################\n",
"\n",
"Using all embeddings\n",
"Top-1 accuracy: 81.65 %\n",
"5-nn accuracy: 83.49 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 96.33 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 99.08 % (Picks the 5 closest distinct classes)\n",
"Top-1 accuracy: 80.00 %\n",
"5-nn accuracy: 80.00 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 86.67 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 93.33 % (Picks the 5 closest distinct classes)\n",
"\n",
"################################\n",
"\n"

File diff suppressed because one or more lines are too long