* Add project code * Logger improvements * Improvements to web demo code * added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py * Fix rotation augmentation * fixed error in docstring, and removed unnecessary replace -1 -> 0 * Readme updates * Share base notebooks * Add notebooks and unify for different datasets * requirements update * fixes * Make evaluate more deterministic * Allow training with clearml * refactor preprocessing and apply linter * Minor fixes * Minor notebook tweaks * Readme updates * Fix PR comments * Remove unneeded code * Add banner to Readme --------- Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
412 lines
12 KiB
Plaintext
412 lines
12 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c20f7fd5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"%load_ext autoreload\n",
|
|
"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ada032d0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import sys\n",
|
|
"import os\n",
|
|
"import os.path as op\n",
|
|
"import pandas as pd\n",
|
|
"import json\n",
|
|
"import base64"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "05682e73",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"sys.path.append(op.abspath('..'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fede7684",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ce531994",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from collections import Counter\n",
|
|
"from itertools import chain\n",
|
|
"\n",
|
|
"import torch\n",
|
|
"import multiprocessing\n",
|
|
"from scipy.spatial import distance_matrix\n",
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f4a2d672",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from torch.utils.data import DataLoader\n",
|
|
"\n",
|
|
"from datasets import SLREmbeddingDataset, collate_fn_padd\n",
|
|
"from datasets.dataset_loader import LocalDatasetLoader\n",
|
|
"from models import embeddings_scatter_plot_splits\n",
|
|
"from models import SPOTER_EMBEDDINGS"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "af8fbe32",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Model and dataset loading"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1d9db764",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import random\n",
|
|
"seed = 43\n",
|
|
"random.seed(seed)\n",
|
|
"np.random.seed(seed)\n",
|
|
"os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
|
|
"torch.manual_seed(seed)\n",
|
|
"torch.cuda.manual_seed(seed)\n",
|
|
"torch.cuda.manual_seed_all(seed)\n",
|
|
"torch.backends.cudnn.deterministic = True\n",
|
|
"torch.use_deterministic_algorithms(True) \n",
|
|
"generator = torch.Generator()\n",
|
|
"generator.manual_seed(seed)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "71224139",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"BASE_DATA_FOLDER = '../data/'\n",
|
|
"os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n",
|
|
"device = torch.device(\"cpu\")\n",
|
|
"if torch.cuda.is_available():\n",
|
|
" device = torch.device(\"cuda\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "013d3774",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# LOAD MODEL FROM CLEARML\n",
|
|
"# from clearml import InputModel\n",
|
|
"# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n",
|
|
"# checkpoint = torch.load(model.get_weights())\n",
|
|
"\n",
|
|
"## Set your path to checkoint here\n",
|
|
"CHECKPOINT_PATH = \"../checkpoints/checkpoint_embed_992.pth\"\n",
|
|
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
|
|
"\n",
|
|
"model = SPOTER_EMBEDDINGS(\n",
|
|
" features=checkpoint[\"config_args\"].vector_length,\n",
|
|
" hidden_dim=checkpoint[\"config_args\"].hidden_dim,\n",
|
|
" norm_emb=checkpoint[\"config_args\"].normalize_embeddings,\n",
|
|
").to(device)\n",
|
|
"\n",
|
|
"model.load_state_dict(checkpoint[\"state_dict\"])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ba6b58f0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"SL_DATASET = 'wlasl' # or 'lsa'\n",
|
|
"if SL_DATASET == 'wlasl':\n",
|
|
" dataset_name = \"wlasl_mapped_mediapipe_only_landmarks_25fps\"\n",
|
|
" num_classes = 100\n",
|
|
" split_dataset_path = \"WLASL100_{}_25fps.csv\"\n",
|
|
"else:\n",
|
|
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
|
|
" num_classes = 64\n",
|
|
" split_dataset_path = \"LSA64_{}.csv\"\n",
|
|
" \n",
|
|
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5643a72c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_dataset_loader(loader_name=None):\n",
|
|
" if loader_name == 'CLEARML':\n",
|
|
" from datasets.clearml_dataset_loader import ClearMLDatasetLoader\n",
|
|
" return ClearMLDatasetLoader()\n",
|
|
" else:\n",
|
|
" return LocalDatasetLoader()\n",
|
|
"\n",
|
|
"dataset_loader = get_dataset_loader()\n",
|
|
"dataset_project = \"Sign Language Recognition\"\n",
|
|
"batch_size = 1\n",
|
|
"dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "04a62088",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def seed_worker(worker_id):\n",
|
|
" worker_seed = torch.initial_seed() % 2**32\n",
|
|
" np.random.seed(worker_seed)\n",
|
|
" random.seed(worker_seed)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "79c837c1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dataloaders = {}\n",
|
|
"splits = ['train', 'val']\n",
|
|
"dfs = {}\n",
|
|
"for split in splits:\n",
|
|
" split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
|
|
" split_set = SLREmbeddingDataset(split_set_path, triplet=False, augmentations=False)\n",
|
|
" data_loader = DataLoader(\n",
|
|
" split_set,\n",
|
|
" batch_size=batch_size,\n",
|
|
" shuffle=False,\n",
|
|
" collate_fn=collate_fn_padd,\n",
|
|
" pin_memory=torch.cuda.is_available(),\n",
|
|
" num_workers=multiprocessing.cpu_count(),\n",
|
|
" worker_init_fn=seed_worker,\n",
|
|
" generator=generator,\n",
|
|
" )\n",
|
|
" dataloaders[split] = data_loader\n",
|
|
" dfs[split] = pd.read_csv(split_set_path)\n",
|
|
"\n",
|
|
"with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n",
|
|
" id_to_label = json.load(fid)\n",
|
|
"id_to_label = {int(key): value for key, value in id_to_label.items()}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8b5bda73",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"labels_split = {}\n",
|
|
"embeddings_split = {}\n",
|
|
"splits = list(dataloaders.keys())\n",
|
|
"with torch.no_grad():\n",
|
|
" for split, dataloader in dataloaders.items():\n",
|
|
" labels_str = []\n",
|
|
" embeddings = []\n",
|
|
" k = 0\n",
|
|
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
|
|
" k += 1\n",
|
|
" inputs = inputs.to(device)\n",
|
|
" masks = masks.to(device)\n",
|
|
" outputs = model(inputs, masks)\n",
|
|
" for n in range(outputs.shape[0]):\n",
|
|
" embeddings.append(outputs[n, 0].cpu().detach().numpy())\n",
|
|
" embeddings_split[split] = embeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0efa0871",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"len(embeddings_split['train']), len(dfs['train'])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ab83c6e2",
|
|
"metadata": {
|
|
"lines_to_next_cell": 2
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"for split in splits:\n",
|
|
" df = dfs[split]\n",
|
|
" df['embeddings'] = embeddings_split[split]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2951638d",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Compute metrics\n",
|
|
"Here computing top1 and top5 metrics either by using only a class centroid or by using the whole dataset to classify vectors.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7399b8ae",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"for use_centroids, str_use_centroids in zip([True, False],\n",
|
|
" ['Using centroids only', 'Using all embeddings']):\n",
|
|
"\n",
|
|
" df_val = dfs['val']\n",
|
|
" df_train = dfs['train']\n",
|
|
" if use_centroids:\n",
|
|
" df_train = dfs['train'].groupby('labels')['embeddings'].apply(np.mean).reset_index()\n",
|
|
" x_train = np.vstack(df_train['embeddings'])\n",
|
|
" x_val = np.vstack(df_val['embeddings'])\n",
|
|
"\n",
|
|
" d_mat = distance_matrix(x_val, x_train, p=2)\n",
|
|
"\n",
|
|
" top5_embs = 0\n",
|
|
" top5_classes = 0\n",
|
|
" knn = 0\n",
|
|
" top1 = 0\n",
|
|
"\n",
|
|
" len_val_dataset = len(df_val)\n",
|
|
" good_samples = []\n",
|
|
"\n",
|
|
" for i in range(d_mat.shape[0]):\n",
|
|
" true_label = df_val.loc[i, 'labels']\n",
|
|
" labels = df_train['labels'].values\n",
|
|
" argsort = np.argsort(d_mat[i])\n",
|
|
" sorted_labels = labels[argsort]\n",
|
|
" if sorted_labels[0] == true_label:\n",
|
|
" top1 += 1\n",
|
|
" if use_centroids:\n",
|
|
" good_samples.append(df_val.loc[i, 'video_id'])\n",
|
|
" else:\n",
|
|
" good_samples.append((df_val.loc[i, 'video_id'],\n",
|
|
" df_train.loc[argsort[0], 'video_id'],\n",
|
|
" i,\n",
|
|
" argsort[0]))\n",
|
|
"\n",
|
|
"\n",
|
|
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
|
|
" knn += 1\n",
|
|
" if true_label in sorted_labels[:5]:\n",
|
|
" top5_embs += 1\n",
|
|
" if true_label in list(dict.fromkeys(sorted_labels))[:5]:\n",
|
|
" top5_classes += 1\n",
|
|
" else:\n",
|
|
" continue\n",
|
|
"\n",
|
|
"\n",
|
|
" print(str_use_centroids)\n",
|
|
"\n",
|
|
"\n",
|
|
" print(f'Top-1 accuracy: {100 * top1 / len_val_dataset : 0.2f} %')\n",
|
|
" if not use_centroids:\n",
|
|
" print(f'5-nn accuracy: {100 * knn / len_val_dataset : 0.2f} % (Picks the class that appears most often in the 5 closest embeddings)')\n",
|
|
" print(f'Top-5 embeddings class match: {100 * top5_embs / len_val_dataset: 0.2f} % (Picks any class in the 5 closest embeddings)')\n",
|
|
" if not use_centroids:\n",
|
|
" print(f'Top-5 unique class match: {100 * top5_classes / len_val_dataset: 0.2f} % (Picks the 5 closest distinct classes)')\n",
|
|
" print('\\n' + '#'*32 + '\\n')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d2aaac6c",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Show some examples (only for WLASL)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b9d1d309",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from IPython.display import Video"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fd2a0cd8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"for row in df_train[df_train.label_name == 'thursday'][:3].itertuples():\n",
|
|
" display(Video(op.join(BASE_DATA_FOLDER, f'wlasl/videos/{row.video_id}.mp4'), embed=True))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"main_language": "python",
|
|
"notebook_metadata_filter": "-all"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.7.13"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|