Fingerspelling embedding + ClearML
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 5,
|
||||
"id": "c20f7fd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 6,
|
||||
"id": "ada032d0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -22,13 +22,12 @@
|
||||
"import os\n",
|
||||
"import os.path as op\n",
|
||||
"import pandas as pd\n",
|
||||
"import json\n",
|
||||
"import base64"
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 7,
|
||||
"id": "05682e73",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -38,7 +37,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 8,
|
||||
"id": "fede7684",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -48,7 +47,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 9,
|
||||
"id": "ce531994",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -64,7 +63,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 10,
|
||||
"id": "f4a2d672",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -87,17 +86,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 11,
|
||||
"id": "1d9db764",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<torch._C.Generator at 0x7f29f89e3ed0>"
|
||||
"<torch._C.Generator at 0x7f010919d710>"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -119,7 +118,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 12,
|
||||
"id": "71224139",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -155,7 +154,7 @@
|
||||
"# checkpoint = torch.load(model.get_weights())\n",
|
||||
"\n",
|
||||
"## Set your path to checkoint here\n",
|
||||
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_6.pth\"\n",
|
||||
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth\"\n",
|
||||
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
|
||||
"\n",
|
||||
"model = SPOTER_EMBEDDINGS(\n",
|
||||
@@ -169,27 +168,28 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 24,
|
||||
"id": "ba6b58f0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SL_DATASET = 'wlasl' # or 'lsa'\n",
|
||||
"if SL_DATASET == 'wlasl':\n",
|
||||
"SL_DATASET = 'basic-signs' # or 'wlasl'\n",
|
||||
"\n",
|
||||
"if SL_DATASET == 'fingerspelling':\n",
|
||||
" dataset_name = \"fingerspelling\"\n",
|
||||
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
|
||||
"elif SL_DATASET == 'wlasl':\n",
|
||||
" dataset_name = \"wlasl\"\n",
|
||||
" num_classes = 100\n",
|
||||
" split_dataset_path = \"WLASL100_train.csv\"\n",
|
||||
"else:\n",
|
||||
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
|
||||
" num_classes = 64\n",
|
||||
" split_dataset_path = \"LSA64_{}.csv\"\n",
|
||||
" \n",
|
||||
" split_dataset_path = \"WLASL100_{}.csv\"\n",
|
||||
"elif SL_DATASET == 'basic-signs':\n",
|
||||
" dataset_name = \"basic-signs\"\n",
|
||||
" split_dataset_path = \"basic-signs_{}.csv\"\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 25,
|
||||
"id": "5643a72c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -228,7 +228,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataloaders = {}\n",
|
||||
"splits = ['train', 'val']\n",
|
||||
"splits = ['train', 'val']\n",
|
||||
"dfs = {}\n",
|
||||
"for split in splits:\n",
|
||||
" split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
|
||||
@@ -269,6 +269,8 @@
|
||||
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
|
||||
" k += 1\n",
|
||||
" inputs = inputs.to(device)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" masks = masks.to(device)\n",
|
||||
" outputs = model(inputs, masks)\n",
|
||||
" for n in range(outputs.shape[0]):\n",
|
||||
@@ -285,7 +287,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(810, 810)"
|
||||
"(164, 164)"
|
||||
]
|
||||
},
|
||||
"execution_count": 19,
|
||||
@@ -299,7 +301,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 21,
|
||||
"id": "ab83c6e2",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
@@ -311,6 +313,70 @@
|
||||
" df['embeddings'] = embeddings_split[split]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"id": "0b9fb9c2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 [1.7327625, -3.015248, -1.4775522, -0.7505071,...\n",
|
||||
"1 [2.0936582, -0.596195, -0.7918601, -0.15896143...\n",
|
||||
"2 [-1.4007742, -0.9608915, 1.3294879, -0.5185398...\n",
|
||||
"3 [1.3280737, -3.299126, -1.0110444, -1.2528414,...\n",
|
||||
"4 [-0.071124956, -0.79259753, 0.7182858, 0.38130...\n",
|
||||
" ... \n",
|
||||
"159 [-1.5968355, 1.9617733, 0.28859574, 1.256657, ...\n",
|
||||
"160 [0.44801116, -1.8377966, 1.1004394, -1.195648,...\n",
|
||||
"161 [2.0584257, 1.6986116, 0.5129896, 0.27279535, ...\n",
|
||||
"162 [1.6695516, -2.967027, -1.5715427, -0.77170163...\n",
|
||||
"163 [1.4977738, -2.6278958, -1.6123883, -0.8420623...\n",
|
||||
"Name: embeddings, Length: 164, dtype: object\n",
|
||||
"0 TOT-ZIENS\n",
|
||||
"1 GOED\n",
|
||||
"2 GOEDENACHT\n",
|
||||
"3 NEE\n",
|
||||
"4 SLECHT\n",
|
||||
" ... \n",
|
||||
"159 SORRY\n",
|
||||
"160 GOEDEMORGEN\n",
|
||||
"161 LINKS\n",
|
||||
"162 TOT-ZIENS\n",
|
||||
"163 GOED\n",
|
||||
"Name: label_name, Length: 164, dtype: object\n",
|
||||
"0 0\n",
|
||||
"1 1\n",
|
||||
"2 2\n",
|
||||
"3 3\n",
|
||||
"4 4\n",
|
||||
" ..\n",
|
||||
"159 7\n",
|
||||
"160 5\n",
|
||||
"161 13\n",
|
||||
"162 0\n",
|
||||
"163 1\n",
|
||||
"Name: labels, Length: 164, dtype: int64\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(dfs['train'][\"embeddings\"])\n",
|
||||
"print(dfs['train'][\"label_name\"])\n",
|
||||
"print(dfs['train'][\"labels\"])\n",
|
||||
"\n",
|
||||
"# only keep these columns\n",
|
||||
"dfs['train'] = dfs['train'][['embeddings', 'label_name', 'labels']]\n",
|
||||
"\n",
|
||||
"# convert embeddings to string\n",
|
||||
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
|
||||
"\n",
|
||||
"# save the dfs['train']\n",
|
||||
"dfs['train'].to_csv(f'../data/{dataset_name}/embeddings.csv', index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "2951638d",
|
||||
@@ -322,7 +388,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 23,
|
||||
"id": "7399b8ae",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -331,16 +397,16 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using centroids only\n",
|
||||
"Top-1 accuracy: 5.19 %\n",
|
||||
"Top-5 embeddings class match: 17.65 % (Picks any class in the 5 closest embeddings)\n",
|
||||
"Top-1 accuracy: 80.00 %\n",
|
||||
"Top-5 embeddings class match: 93.33 % (Picks any class in the 5 closest embeddings)\n",
|
||||
"\n",
|
||||
"################################\n",
|
||||
"\n",
|
||||
"Using all embeddings\n",
|
||||
"Top-1 accuracy: 5.31 %\n",
|
||||
"5-nn accuracy: 5.56 % (Picks the class that appears most often in the 5 closest embeddings)\n",
|
||||
"Top-5 embeddings class match: 15.43 % (Picks any class in the 5 closest embeddings)\n",
|
||||
"Top-5 unique class match: 15.56 % (Picks the 5 closest distinct classes)\n",
|
||||
"Top-1 accuracy: 80.00 %\n",
|
||||
"5-nn accuracy: 80.00 % (Picks the class that appears most often in the 5 closest embeddings)\n",
|
||||
"Top-5 embeddings class match: 86.67 % (Picks any class in the 5 closest embeddings)\n",
|
||||
"Top-5 unique class match: 93.33 % (Picks the 5 closest distinct classes)\n",
|
||||
"\n",
|
||||
"################################\n",
|
||||
"\n"
|
||||
@@ -375,13 +441,13 @@
|
||||
" sorted_labels = labels[argsort]\n",
|
||||
" if sorted_labels[0] == true_label:\n",
|
||||
" top1 += 1\n",
|
||||
" if use_centroids:\n",
|
||||
" good_samples.append(df_val.loc[i, 'video_id'])\n",
|
||||
" else:\n",
|
||||
" good_samples.append((df_val.loc[i, 'video_id'],\n",
|
||||
" df_train.loc[argsort[0], 'video_id'],\n",
|
||||
" i,\n",
|
||||
" argsort[0]))\n",
|
||||
" # if use_centroids:\n",
|
||||
" # good_samples.append(df_val.loc[i, 'video_id'])\n",
|
||||
" # else:\n",
|
||||
" # good_samples.append((df_val.loc[i, 'video_id'],\n",
|
||||
" # df_train.loc[argsort[0], 'video_id'],\n",
|
||||
" # i,\n",
|
||||
" # argsort[0]))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
|
||||
|
||||
Reference in New Issue
Block a user