{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3f1a9c0e",
   "metadata": {},
   "source": [
    "# SPOTER embeddings: export and nearest-neighbour evaluation\n",
    "\n",
    "This notebook loads a trained `SPOTER_EMBEDDINGS` checkpoint, embeds the train and val splits of the selected sign-language dataset, saves the train embeddings to CSV, and reports nearest-neighbour retrieval metrics on the val split (against per-class centroids and against all train embeddings)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c20f7fd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ada032d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import multiprocessing\n",
    "import os\n",
    "import os.path as op\n",
    "import random\n",
    "import sys\n",
    "from collections import Counter\n",
    "from itertools import chain\n",
    "\n",
    "# Set before torch is imported so it is in place before cuBLAS initialises; deterministic\n",
    "# cuBLAS kernels require it once torch.use_deterministic_algorithms(True) is enabled.\n",
    "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\"\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from IPython.display import Video, display\n",
    "from scipy.spatial import distance_matrix\n",
    "from torch.utils.data import DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05682e73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the project packages (`datasets`, `models`) one directory up importable.\n",
    "sys.path.append(op.abspath('..'))\n",
    "\n",
    "from datasets import SLREmbeddingDataset, collate_fn_padd\n",
    "from datasets.dataset_loader import LocalDatasetLoader\n",
    "from models import embeddings_scatter_plot_splits\n",
    "from models import SPOTER_EMBEDDINGS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c2e4b19",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d9db764",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 43\n",
    "\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "# Setting this at runtime does not change hashing in the already-running kernel;\n",
    "# it only reaches subprocesses that start a fresh interpreter.\n",
    "os.environ[\"PYTHONHASHSEED\"] = str(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.use_deterministic_algorithms(True)\n",
    "\n",
    "# Dedicated RNG handed to the DataLoaders below.\n",
    "generator = torch.Generator()\n",
    "generator.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71224139",
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE_DATA_FOLDER = '../data/'\n",
    "# Presumably read by the project's dataset code - confirm.\n",
    "os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af8fbe32",
   "metadata": {},
   "source": [
    "## Model and dataset loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "013d3774",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally load the weights from ClearML instead:\n",
    "# from clearml import InputModel\n",
    "# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n",
    "# checkpoint = torch.load(model.get_weights())\n",
    "\n",
    "# Set the path to your checkpoint here.\n",
    "CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth\"\n",
    "# torch.load unpickles the file (it also stores the training config object),\n",
    "# so only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
    "config_args = checkpoint[\"config_args\"]\n",
    "\n",
    "model = SPOTER_EMBEDDINGS(\n",
    "    features=config_args.vector_length,\n",
    "    hidden_dim=config_args.hidden_dim,\n",
    "    norm_emb=config_args.normalize_embeddings,\n",
    ").to(device)\n",
    "model.load_state_dict(checkpoint[\"state_dict\"])\n",
    "# Inference only: disable training-mode behaviour (e.g. dropout) so the embeddings\n",
    "# are deterministic and match how the model is meant to be evaluated.\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba6b58f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "SL_DATASET = 'basic-signs'  # one of: 'basic-signs', 'wlasl', 'fingerspelling'\n",
    "\n",
    "# File-name template of each dataset's split CSVs ({} is replaced by the split name).\n",
    "DATASET_SPLIT_TEMPLATES = {\n",
    "    'fingerspelling': \"fingerspelling_{}.csv\",\n",
    "    'wlasl': \"WLASL100_{}.csv\",\n",
    "    'basic-signs': \"basic-signs_{}.csv\",\n",
    "}\n",
    "if SL_DATASET not in DATASET_SPLIT_TEMPLATES:\n",
    "    raise ValueError(f\"Unknown SL_DATASET {SL_DATASET!r}; expected one of {sorted(DATASET_SPLIT_TEMPLATES)}\")\n",
    "\n",
    "dataset_name = SL_DATASET\n",
    "split_dataset_path = DATASET_SPLIT_TEMPLATES[SL_DATASET]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5643a72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset_loader(loader_name=None):\n",
    "    \"\"\"Return the dataset loader to use.\n",
    "\n",
    "    'CLEARML' selects the ClearML-backed loader (imported only when selected);\n",
    "    any other value, including the default None, selects LocalDatasetLoader.\n",
    "    \"\"\"\n",
    "    if loader_name == 'CLEARML':\n",
    "        from datasets.clearml_dataset_loader import ClearMLDatasetLoader\n",
    "        return ClearMLDatasetLoader()\n",
    "    return LocalDatasetLoader()\n",
    "\n",
    "\n",
    "DATASET_PROJECT = \"Sign Language Recognition\"\n",
    "BATCH_SIZE = 1\n",
    "\n",
    "dataset_loader = get_dataset_loader()\n",
    "dataset_folder = dataset_loader.get_dataset_folder(DATASET_PROJECT, dataset_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a62088",
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_worker(worker_id):\n",
    "    \"\"\"DataLoader worker_init_fn: seed NumPy and `random` from torch's per-worker seed.\"\"\"\n",
    "    worker_seed = torch.initial_seed() % 2**32\n",
    "    np.random.seed(worker_seed)\n",
    "    random.seed(worker_seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79c837c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "splits = ['train', 'val']\n",
    "dataloaders = {}\n",
    "dfs = {}\n",
    "for split in splits:\n",
    "    split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
    "    split_set = SLREmbeddingDataset(split_set_path, triplet=False, augmentations=False)\n",
    "    dataloaders[split] = DataLoader(\n",
    "        split_set,\n",
    "        batch_size=BATCH_SIZE,\n",
    "        # No shuffling: embeddings are later aligned row-by-row with dfs[split]\n",
    "        # (assumes SLREmbeddingDataset yields samples in CSV row order - confirm).\n",
    "        shuffle=False,\n",
    "        collate_fn=collate_fn_padd,\n",
    "        pin_memory=torch.cuda.is_available(),\n",
    "        num_workers=multiprocessing.cpu_count(),\n",
    "        worker_init_fn=seed_worker,\n",
    "        generator=generator,\n",
    "    )\n",
    "    dfs[split] = pd.read_csv(split_set_path)\n",
    "\n",
    "with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n",
    "    id_to_label = {int(key): value for key, value in json.load(fid).items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d8e21a7",
   "metadata": {},
   "source": [
    "## Compute embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b5bda73",
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings_split = {}\n",
    "with torch.no_grad():\n",
    "    for split, dataloader in dataloaders.items():\n",
    "        embeddings = []\n",
    "        for inputs, _labels, masks in dataloader:\n",
    "            outputs = model(inputs.to(device), masks.to(device))\n",
    "            # One embedding per sample in the batch: index 0 along the second output axis.\n",
    "            embeddings.extend(outputs[:, 0].cpu().numpy())\n",
    "        embeddings_split[split] = embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab83c6e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach each embedding to its CSV row; the check below only catches count mismatches,\n",
    "# row order relies on the unshuffled DataLoaders.\n",
    "for split in splits:\n",
    "    n_embeddings, n_rows = len(embeddings_split[split]), len(dfs[split])\n",
    "    assert n_embeddings == n_rows, f\"{split}: {n_embeddings} embeddings for {n_rows} CSV rows\"\n",
    "    dfs[split]['embeddings'] = embeddings_split[split]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b4f6d32",
   "metadata": {},
   "source": [
    "## Save train embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b9fb9c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export a trimmed copy so that dfs['train'] keeps all its columns (e.g. video_id,\n",
    "# used in the examples section) and no chained assignment on a slice happens.\n",
    "train_export = dfs['train'][['embeddings', 'label_name', 'labels']].copy()\n",
    "# Plain lists are written to CSV as \"[x, y, ...]\"; str() of a NumPy array may wrap\n",
    "# or truncate, so parse this column when reading the file back.\n",
    "train_export['embeddings2'] = train_export['embeddings'].apply(lambda emb: emb.tolist())\n",
    "train_export.to_csv(op.join(BASE_DATA_FOLDER, dataset_name, 'embeddings.csv'), index=False)\n",
    "\n",
    "train_export.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2951638d",
   "metadata": {},
   "source": [
    "## Compute metrics\n",
    "Each validation embedding is classified by its Euclidean distance to the training embeddings, either against a single centroid (mean embedding) per class or against every training embedding. Top-1, 5-nn and top-5 match rates are reported for both settings.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4c7a1b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_centroids(df):\n",
    "    \"\"\"Return one row per label holding the mean of that label's embeddings.\n",
    "\n",
    "    The result has a 'labels' column and an 'embeddings' column (one 1-D array per row).\n",
    "    \"\"\"\n",
    "    return (\n",
    "        df.groupby('labels')['embeddings']\n",
    "        .apply(lambda embs: np.mean(np.vstack(embs), axis=0))\n",
    "        .reset_index()\n",
    "    )\n",
    "\n",
    "\n",
    "def nearest_neighbour_metrics(df_train, df_val, k=5):\n",
    "    \"\"\"Classify each val embedding by Euclidean distance to the train embeddings.\n",
    "\n",
    "    Both frames need an 'embeddings' column (one 1-D array per row) and a 'labels' column.\n",
    "    Returns a dict of fractions of the val rows for which the true label is:\n",
    "      - 'top1': the label of the closest train embedding;\n",
    "      - 'knn': the most common label among the k closest train embeddings;\n",
    "      - 'topk_embs': present among the k closest train embeddings;\n",
    "      - 'topk_classes': among the first k distinct labels ordered by distance.\n",
    "    Raises ValueError if df_val is empty.\n",
    "    \"\"\"\n",
    "    if len(df_val) == 0:\n",
    "        raise ValueError(\"df_val is empty; cannot compute metrics\")\n",
    "    x_train = np.vstack(df_train['embeddings'])\n",
    "    x_val = np.vstack(df_val['embeddings'])\n",
    "    train_labels = df_train['labels'].to_numpy()\n",
    "    val_labels = df_val['labels'].to_numpy()\n",
    "\n",
    "    d_mat = distance_matrix(x_val, x_train, p=2)\n",
    "    hits = Counter()\n",
    "    for distances, true_label in zip(d_mat, val_labels):\n",
    "        sorted_labels = train_labels[np.argsort(distances)]\n",
    "        hits['top1'] += sorted_labels[0] == true_label\n",
    "        # most_common orders ties by first occurrence, so the closest label wins a tie.\n",
    "        hits['knn'] += Counter(sorted_labels[:k]).most_common(1)[0][0] == true_label\n",
    "        hits['topk_embs'] += true_label in sorted_labels[:k]\n",
    "        hits['topk_classes'] += true_label in list(dict.fromkeys(sorted_labels))[:k]\n",
    "    return {name: hits[name] / len(val_labels) for name in ('top1', 'knn', 'topk_embs', 'topk_classes')}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7399b8ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "df_val = dfs['val']\n",
    "for use_centroids, str_use_centroids in [(True, 'Using centroids only'), (False, 'Using all embeddings')]:\n",
    "    df_reference = compute_centroids(dfs['train']) if use_centroids else dfs['train']\n",
    "    metrics = nearest_neighbour_metrics(df_reference, df_val)\n",
    "\n",
    "    print(str_use_centroids)\n",
    "    print(f\"Top-1 accuracy: {100 * metrics['top1'] : 0.2f} %\")\n",
    "    if not use_centroids:\n",
    "        print(f\"5-nn accuracy: {100 * metrics['knn'] : 0.2f} % (Picks the class that appears most often in the 5 closest embeddings)\")\n",
    "    print(f\"Top-5 embeddings class match: {100 * metrics['topk_embs']: 0.2f} % (Picks any class in the 5 closest embeddings)\")\n",
    "    if not use_centroids:\n",
    "        print(f\"Top-5 unique class match: {100 * metrics['topk_classes']: 0.2f} % (Picks the 5 closest distinct classes)\")\n",
    "    print('\\n' + '#' * 32 + '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2aaac6c",
   "metadata": {},
   "source": [
    "## Show some examples (only for WLASL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd2a0cd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "EXAMPLE_LABEL = 'thursday'\n",
    "N_EXAMPLES = 3\n",
    "\n",
    "if SL_DATASET == 'wlasl':\n",
    "    df_examples = dfs['train'][dfs['train']['label_name'] == EXAMPLE_LABEL].head(N_EXAMPLES)\n",
    "    for row in df_examples.itertuples():\n",
    "        display(Video(op.join(BASE_DATA_FOLDER, f'wlasl/videos/{row.video_id}.mp4'), embed=True))\n",
    "else:\n",
    "    print(f\"Video examples are only available for WLASL (SL_DATASET={SL_DATASET!r}).\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}