Files
spoterembedding/notebooks/embeddings_evaluation.ipynb
Mathias Claassen 81bbf66aab Initial codebase (#1)
* Add project code

* Logger improvements

* Improvements to web demo code

* added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py

* Fix rotation augmentation

* fixed error in docstring, and removed unnecessary replace -1 -> 0

* Readme updates

* Share base notebooks

* Add notebooks and unify for different datasets

* requirements update

* fixes

* Make evaluate more deterministic

* Allow training with clearml

* refactor preprocessing and apply linter

* Minor fixes

* Minor notebook tweaks

* Readme updates

* Fix PR comments

* Remove unneeded code

* Add banner to Readme

---------

Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
2023-03-03 10:07:54 -03:00

412 lines
12 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "c20f7fd5",
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ada032d0",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import os.path as op\n",
"import pandas as pd\n",
"import json\n",
"import base64"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "05682e73",
"metadata": {},
"outputs": [],
"source": [
"# Make the parent directory importable so the project's `datasets` and `models` packages resolve.\n",
"# Guarded so that re-running this cell does not keep appending the same entry to sys.path.\n",
"repo_root = op.abspath('..')\n",
"if repo_root not in sys.path:\n",
"    sys.path.append(repo_root)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fede7684",
"metadata": {},
"outputs": [],
"source": [
"# cuBLAS workspace setting that goes with torch.use_deterministic_algorithms(True) on CUDA.\n",
"# It is placed in the environment here, before `torch` is imported in the next cell.\n",
"os.environ.update({\"CUBLAS_WORKSPACE_CONFIG\": \":16:8\"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ce531994",
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"from itertools import chain\n",
"\n",
"import torch\n",
"import multiprocessing\n",
"from scipy.spatial import distance_matrix\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f4a2d672",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"from datasets import SLREmbeddingDataset, collate_fn_padd\n",
"from datasets.dataset_loader import LocalDatasetLoader\n",
"from models import embeddings_scatter_plot_splits\n",
"from models import SPOTER_EMBEDDINGS"
]
},
{
"cell_type": "markdown",
"id": "af8fbe32",
"metadata": {},
"source": [
"## Model and dataset loading"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d9db764",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"# One seed shared by every random number generator used in this notebook.\n",
"seed = 43\n",
"os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
"\n",
"# Global generators: Python, NumPy, PyTorch on CPU and on every CUDA device.\n",
"random.seed(seed)\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"torch.cuda.manual_seed_all(seed)\n",
"\n",
"# Restrict PyTorch / cuDNN to deterministic implementations.\n",
"torch.use_deterministic_algorithms(True)\n",
"torch.backends.cudnn.deterministic = True\n",
"\n",
"# Dedicated generator that is later handed to the DataLoaders.\n",
"generator = torch.Generator()\n",
"generator.manual_seed(seed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "71224139",
"metadata": {},
"outputs": [],
"source": [
"# Root of the data directory, also exported through the environment.\n",
"BASE_DATA_FOLDER = '../data/'\n",
"os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n",
"\n",
"# Use the GPU when one is available, otherwise stay on the CPU.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "013d3774",
"metadata": {},
"outputs": [],
"source": [
"# LOAD MODEL FROM CLEARML\n",
"# from clearml import InputModel\n",
"# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n",
"# checkpoint = torch.load(model.get_weights())\n",
"\n",
"## Set the path to your checkpoint here\n",
"CHECKPOINT_PATH = \"../checkpoints/checkpoint_embed_992.pth\"\n",
"# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
"\n",
"# Rebuild the network from the hyper-parameters stored next to the weights in the checkpoint.\n",
"config_args = checkpoint[\"config_args\"]\n",
"model = SPOTER_EMBEDDINGS(\n",
"    features=config_args.vector_length,\n",
"    hidden_dim=config_args.hidden_dim,\n",
"    norm_emb=config_args.normalize_embeddings,\n",
")\n",
"model = model.to(device)\n",
"\n",
"model.load_state_dict(checkpoint[\"state_dict\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ba6b58f0",
"metadata": {},
"outputs": [],
"source": [
"SL_DATASET = 'wlasl'  # or 'lsa'\n",
"\n",
"# (dataset folder name, number of classes, split CSV template; `{}` receives the split name)\n",
"if SL_DATASET == 'wlasl':\n",
"    dataset_settings = (\"wlasl_mapped_mediapipe_only_landmarks_25fps\", 100, \"WLASL100_{}_25fps.csv\")\n",
"else:\n",
"    dataset_settings = (\"lsa64_mapped_mediapipe_only_landmarks_25fps\", 64, \"LSA64_{}.csv\")\n",
"dataset_name, num_classes, split_dataset_path = dataset_settings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5643a72c",
"metadata": {},
"outputs": [],
"source": [
"def get_dataset_loader(loader_name=None):\n",
"    \"\"\"Return a ClearMLDatasetLoader when `loader_name` is 'CLEARML', otherwise a LocalDatasetLoader.\"\"\"\n",
"    if loader_name != 'CLEARML':\n",
"        return LocalDatasetLoader()\n",
"    # Imported only on this path, so the ClearML loader module is not touched for local use.\n",
"    from datasets.clearml_dataset_loader import ClearMLDatasetLoader\n",
"    return ClearMLDatasetLoader()\n",
"\n",
"\n",
"batch_size = 1\n",
"dataset_project = \"Sign Language Recognition\"\n",
"dataset_loader = get_dataset_loader()\n",
"# Folder that holds the split CSV files and id_to_label.json read further down.\n",
"dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04a62088",
"metadata": {},
"outputs": [],
"source": [
"def seed_worker(worker_id):\n",
"    \"\"\"Seed NumPy and `random` from the current torch seed.\n",
"\n",
"    Used as the DataLoader `worker_init_fn`; `worker_id` is accepted but not used.\n",
"    \"\"\"\n",
"    derived_seed = torch.initial_seed() % 2**32  # reduced modulo 2**32 before seeding\n",
"    for seed_rng in (np.random.seed, random.seed):\n",
"        seed_rng(derived_seed)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "79c837c1",
"metadata": {},
"outputs": [],
"source": [
"splits = ['train', 'val']\n",
"dataloaders = {}\n",
"dfs = {}\n",
"\n",
"# Settings shared by the DataLoader of every split.\n",
"loader_options = dict(\n",
"    batch_size=batch_size,\n",
"    shuffle=False,\n",
"    collate_fn=collate_fn_padd,\n",
"    pin_memory=torch.cuda.is_available(),\n",
"    num_workers=multiprocessing.cpu_count(),\n",
"    worker_init_fn=seed_worker,\n",
"    generator=generator,\n",
")\n",
"\n",
"for split in splits:\n",
"    split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
"    # No triplets and no augmentations for evaluation.\n",
"    split_set = SLREmbeddingDataset(split_set_path, triplet=False, augmentations=False)\n",
"    dataloaders[split] = DataLoader(split_set, **loader_options)\n",
"    # Keep the split's CSV table next to its loader; embeddings are attached to it later.\n",
"    dfs[split] = pd.read_csv(split_set_path)\n",
"\n",
"with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n",
"    # JSON object keys are strings; convert them to integer ids.\n",
"    id_to_label = {int(key): value for key, value in json.load(fid).items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b5bda73",
"metadata": {},
"outputs": [],
"source": [
"embeddings_split = {}\n",
"splits = list(dataloaders.keys())\n",
"\n",
"# Inference only: put the model in evaluation mode so that train-only layers\n",
"# (e.g. dropout, if the model has any) do not perturb the embeddings.\n",
"model.eval()\n",
"with torch.no_grad():\n",
"    for split, dataloader in dataloaders.items():\n",
"        embeddings = []\n",
"        # The labels in each batch are not needed here; they are read from the split CSV later.\n",
"        for inputs, _labels, masks in dataloader:\n",
"            outputs = model(inputs.to(device), masks.to(device))\n",
"            # Keep element 0 along the second axis of every sample's output,\n",
"            # moving the whole batch to host memory in one transfer.\n",
"            embeddings.extend(outputs[:, 0].cpu().numpy())\n",
"        # One entry per sample, in loader order.\n",
"        embeddings_split[split] = embeddings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0efa0871",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: the two numbers should match (one embedding per row of the train table).\n",
"n_train_embeddings = len(embeddings_split['train'])\n",
"n_train_rows = len(dfs['train'])\n",
"n_train_embeddings, n_train_rows"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ab83c6e2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"# Store every sample's embedding as a new `embeddings` column of its split's table.\n",
"for split, split_embeddings in embeddings_split.items():\n",
"    dfs[split]['embeddings'] = split_embeddings"
]
},
{
"cell_type": "markdown",
"id": "2951638d",
"metadata": {},
"source": [
"## Compute metrics\n",
"This section computes top-1 and top-5 accuracy on the validation split in two ways: by comparing each validation embedding against a single centroid per class, and by comparing it against every embedding of the train split.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7399b8ae",
"metadata": {},
"outputs": [],
"source": [
"# Two passes over the validation split: first against one centroid per class, then against\n",
"# every train embedding. After the loop `df_train` is bound to the full train table.\n",
"for use_centroids, str_use_centroids in zip([True, False],\n",
"                                            ['Using centroids only', 'Using all embeddings']):\n",
"\n",
"    df_val = dfs['val']\n",
"    df_train = dfs['train']\n",
"    if use_centroids:\n",
"        # One row per class: the mean of that class's train embeddings.\n",
"        df_train = dfs['train'].groupby('labels')['embeddings'].apply(np.mean).reset_index()\n",
"    x_train = np.vstack(df_train['embeddings'])\n",
"    x_val = np.vstack(df_val['embeddings'])\n",
"\n",
"    # Minkowski distance with p=2 from every val embedding (rows) to every reference (columns).\n",
"    d_mat = distance_matrix(x_val, x_train, p=2)\n",
"\n",
"    # Label arrays read by position: row i of d_mat is row i of df_val and column j is\n",
"    # row j of df_train, whatever index labels the frames carry. Built once, outside the loop.\n",
"    labels = df_train['labels'].values\n",
"    val_labels = df_val['labels'].values\n",
"\n",
"    top5_embs = 0\n",
"    top5_classes = 0\n",
"    knn = 0\n",
"    top1 = 0\n",
"\n",
"    len_val_dataset = len(df_val)\n",
"    # Correctly classified (top-1) val samples; with all embeddings, also the matched train sample.\n",
"    good_samples = []\n",
"\n",
"    for i in range(d_mat.shape[0]):\n",
"        true_label = val_labels[i]\n",
"        # Reference labels ordered from the closest to the farthest embedding.\n",
"        argsort = np.argsort(d_mat[i])\n",
"        sorted_labels = labels[argsort]\n",
"        if sorted_labels[0] == true_label:\n",
"            top1 += 1\n",
"            val_video_id = df_val['video_id'].iloc[i]\n",
"            if use_centroids:\n",
"                good_samples.append(val_video_id)\n",
"            else:\n",
"                good_samples.append((val_video_id,\n",
"                                     df_train['video_id'].iloc[argsort[0]],\n",
"                                     i,\n",
"                                     argsort[0]))\n",
"\n",
"        # Majority vote among the 5 closest references.\n",
"        if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
"            knn += 1\n",
"        # True class present among the 5 closest references.\n",
"        if true_label in sorted_labels[:5]:\n",
"            top5_embs += 1\n",
"        # True class among the 5 closest distinct classes (dict.fromkeys keeps first occurrences in order).\n",
"        if true_label in list(dict.fromkeys(sorted_labels))[:5]:\n",
"            top5_classes += 1\n",
"\n",
"\n",
"    print(str_use_centroids)\n",
"\n",
"\n",
"    print(f'Top-1 accuracy: {100 * top1 / len_val_dataset : 0.2f} %')\n",
"    if not use_centroids:\n",
"        print(f'5-nn accuracy: {100 * knn / len_val_dataset : 0.2f} % (Picks the class that appears most often in the 5 closest embeddings)')\n",
"    print(f'Top-5 embeddings class match: {100 * top5_embs / len_val_dataset: 0.2f} % (Picks any class in the 5 closest embeddings)')\n",
"    if not use_centroids:\n",
"        print(f'Top-5 unique class match: {100 * top5_classes / len_val_dataset: 0.2f} % (Picks the 5 closest distinct classes)')\n",
"    print('\\n' + '#'*32 + '\\n')"
]
},
{
"cell_type": "markdown",
"id": "d2aaac6c",
"metadata": {},
"source": [
"## Show some examples (only for WLASL)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9d1d309",
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Video"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd2a0cd8",
"metadata": {},
"outputs": [],
"source": [
"# Preview up to three train videos of one sign.\n",
"# The train table is read from `dfs` directly rather than from whichever `df_train`\n",
"# the metrics loop left behind, so this cell does not depend on that loop's last pass.\n",
"example_label = 'thursday'\n",
"train_table = dfs['train']\n",
"for row in train_table[train_table.label_name == example_label][:3].itertuples():\n",
"    display(Video(op.join(BASE_DATA_FOLDER, f'wlasl/videos/{row.video_id}.mp4'), embed=True))"
]
}
],
"metadata": {
"jupytext": {
"cell_metadata_filter": "-all",
"main_language": "python",
"notebook_metadata_filter": "-all"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}