{ "cells": [
{ "cell_type": "code", "execution_count": 1, "id": "c20f7fd5", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] },
{ "cell_type": "code", "execution_count": 2, "id": "ada032d0", "metadata": {}, "outputs": [], "source": [
"# Standard library. Third-party imports follow once the CUDA environment variables are set.\n",
"import base64\n",
"import json\n",
"import multiprocessing\n",
"import os\n",
"import os.path as op\n",
"import random\n",
"import sys\n",
"from collections import Counter\n",
"from itertools import chain" ] },
{ "cell_type": "code", "execution_count": 3, "id": "05682e73", "metadata": {}, "outputs": [], "source": [
"# Make the project root importable so the local `datasets` and `models` packages below resolve.\n",
"sys.path.append(op.abspath('..'))" ] },
{ "cell_type": "code", "execution_count": 4, "id": "fede7684", "metadata": {}, "outputs": [], "source": [
"# cuBLAS needs a fixed workspace size for torch.use_deterministic_algorithms(True) on CUDA >= 10.2.\n",
"# Set it before torch is imported so it is in place before any CUDA work starts.\n",
"os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\"" ] },
{ "cell_type": "code", "execution_count": 5, "id": "ce531994", "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from scipy.spatial import distance_matrix\n",
"from torch.utils.data import DataLoader" ] },
{ "cell_type": "code", "execution_count": 6, "id": "f4a2d672", "metadata": {}, "outputs": [], "source": [
"from datasets import SLREmbeddingDataset, collate_fn_padd\n",
"from datasets.dataset_loader import LocalDatasetLoader\n",
"from models import SPOTER_EMBEDDINGS, embeddings_scatter_plot_splits" ] },
{ "cell_type": "markdown", "id": "af8fbe32", "metadata": {}, "source": [ "## Model and dataset loading" ] },
{ "cell_type": "code", "execution_count": 7, "id": "1d9db764", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Seed every RNG in use (Python, numpy, torch CPU/CUDA) and request deterministic kernels.\n",
"seed = 43\n",
"random.seed(seed)\n",
"np.random.seed(seed)\n",
"# Only affects interpreters started after this point (e.g. spawned workers): the running one fixed its hash seed at startup.\n",
"os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"torch.cuda.manual_seed_all(seed)\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.use_deterministic_algorithms(True)\n",
"generator = torch.Generator()\n",
"generator.manual_seed(seed)" ] },
{ "cell_type": "code", "execution_count": 8, "id": "71224139", "metadata": {}, "outputs": [], "source": [
"BASE_DATA_FOLDER = '../data/'\n",
"os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] },
{ "cell_type": "code", "execution_count": 9, "id": "013d3774", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Alternative: load the model from ClearML instead of a local file\n",
"# from clearml import InputModel\n",
"# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n",
"# checkpoint = torch.load(model.get_weights())\n",
"\n",
"# Set your path to the checkpoint here\n",
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth\"\n",
"# torch.load unpickles the file (it carries a `config_args` object, not only tensors): only load trusted checkpoints.\n",
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
"config_args = checkpoint[\"config_args\"]\n",
"\n",
"model = SPOTER_EMBEDDINGS(\n",
"    features=config_args.vector_length,\n",
"    hidden_dim=config_args.hidden_dim,\n",
"    norm_emb=config_args.normalize_embeddings,\n",
").to(device)\n",
"\n",
"load_result = model.load_state_dict(checkpoint[\"state_dict\"])\n",
"# Inference only: torch.no_grad() does not leave training mode, so switch it off explicitly.\n",
"# This disables dropout-style train-time behaviour, if the model has any.\n",
"model.eval()\n",
"load_result" ] },
{ "cell_type": "code", "execution_count": 75, "id": "ba6b58f0", "metadata": {}, "outputs": [], "source": [
"SL_DATASET = 'wlasl'  # or 'lsa'\n",
"if SL_DATASET == 'wlasl':\n",
"    dataset_name = \"fingerspelling\"\n",
"    num_classes = 100\n",
"    split_dataset_path = \"fingerspelling_{}.csv\"\n",
"else:\n",
"    dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
"    num_classes = 64\n",
"    split_dataset_path = \"LSA64_{}.csv\"" ] },
{ "cell_type": "code", "execution_count": 76, "id": "5643a72c", "metadata": {}, "outputs": [], "source": [
"def get_dataset_loader(loader_name=None):\n",
"    \"\"\"Return the dataset loader backend: ClearML when loader_name == 'CLEARML', the local loader otherwise.\"\"\"\n",
"    if loader_name == 'CLEARML':\n",
"        # Imported lazily so ClearML is only required when it is actually selected.\n",
"        from datasets.clearml_dataset_loader import ClearMLDatasetLoader\n",
"        return ClearMLDatasetLoader()\n",
"    return LocalDatasetLoader()\n",
"\n",
"dataset_loader = get_dataset_loader()\n",
"dataset_project = \"Sign Language Recognition\"\n",
"batch_size = 1\n",
"dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)" ] },
{ "cell_type": "code", "execution_count": 77, "id": "04a62088", "metadata": {}, "outputs": [], "source": [
"def seed_worker(worker_id):\n",
"    \"\"\"DataLoader `worker_init_fn`: seed numpy and `random` in each worker from that worker's torch seed.\n",
"\n",
"    `worker_id` is passed by the DataLoader and is not needed here.\n",
"    \"\"\"\n",
"    worker_seed = torch.initial_seed() % 2**32\n",
"    np.random.seed(worker_seed)\n",
"    random.seed(worker_seed)" ] },
{ "cell_type": "code", "execution_count": 86, "id": "79c837c1", "metadata": {}, "outputs": [], "source": [
"dataloaders = {}\n",
"splits = ['train', 'val']\n",
"dfs = {}\n",
"for split in splits:\n",
"    split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
"    split_set = SLREmbeddingDataset(split_set_path, triplet=False, augmentations=False)\n",
"    dataloaders[split] = DataLoader(\n",
"        split_set,\n",
"        batch_size=batch_size,\n",
"        shuffle=False,  # keep file order so embeddings can be matched back to dfs[split] rows\n",
"        collate_fn=collate_fn_padd,\n",
"        pin_memory=torch.cuda.is_available(),\n",
"        num_workers=multiprocessing.cpu_count(),\n",
"        worker_init_fn=seed_worker,\n",
"        generator=generator,\n",
"    )\n",
"    dfs[split] = pd.read_csv(split_set_path)\n",
"\n",
"with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n",
"    # JSON object keys are always strings, so convert them back to integer class ids.\n",
"    id_to_label = {int(key): value for key, value in json.load(fid).items()}" ] },
{ "cell_type": "code", "execution_count": 87, "id": "8b5bda73", "metadata": {}, "outputs": [], "source": [
"# Embed every sample of each split, in DataLoader order (shuffle=False above).\n",
"embeddings_split = {}\n",
"splits = list(dataloaders.keys())\n",
"with torch.no_grad():\n",
"    for split, dataloader in dataloaders.items():\n",
"        embeddings = []\n",
"        for inputs, _labels, masks in dataloader:\n",
"            inputs = inputs.to(device)\n",
"            masks = masks.to(device)\n",
"            outputs = model(inputs, masks)\n",
"            # Keep outputs[n, 0] for each sample n in the batch (assumes index 0 is the embedding slot - TODO confirm).\n",
"            for n in range(outputs.shape[0]):\n",
"                embeddings.append(outputs[n, 0].cpu().numpy())\n",
"        embeddings_split[split] = embeddings" ] },
{ "cell_type": "code", "execution_count": 88, "id": "0efa0871", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(560, 560)" ] }, "execution_count": 88, "metadata": {}, "output_type": "execute_result" } ], "source": [
"# Embeddings are attached to the split DataFrames by position below, so the counts must agree.\n",
"for split in splits:\n",
"    assert len(embeddings_split[split]) == len(dfs[split]), f'{split}: embedding/row count mismatch'\n",
"len(embeddings_split['train']), len(dfs['train'])" ] },
{ "cell_type": "code", "execution_count": 90, "id": "ab83c6e2", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [
"# Attach embeddings by position. Assumes SLREmbeddingDataset yields rows in CSV order (shuffle=False above).\n",
"for split in splits:\n",
"    dfs[split]['embeddings'] = embeddings_split[split]" ] },
{ "cell_type": "code", "execution_count": null, "id": "0b9fb9c2", "metadata": {}, "outputs": [], "source": [
"# Export the train-split embeddings together with their labels.\n",
"# `embeddings2` stores each vector as a plain list, which round-trips exactly. The `embeddings` column is kept\n",
"# for compatibility, but it is written as numpy's text repr and may not round-trip exactly.\n",
"# Work on a copy so dfs[EXPORT_SPLIT] keeps all its columns (e.g. `video_id`, which the examples section uses).\n",
"EXPORT_SPLIT = 'train'\n",
"export_path = op.join(BASE_DATA_FOLDER, dataset_name, 'embeddings.csv')\n",
"df_export = dfs[EXPORT_SPLIT][['embeddings', 'label_name', 'labels']].copy()\n",
"df_export['embeddings2'] = df_export['embeddings'].apply(lambda emb: emb.tolist())\n",
"df_export.to_csv(export_path, index=False)\n",
"df_export.head()" ] },
{ "cell_type": "markdown", "id": "2951638d", "metadata": {}, "source": [
"## Compute metrics\n",
"Classify each validation embedding by nearest-neighbour search and report top-1 and top-5 accuracy. The comparison is made either against one centroid (mean embedding) per class or against every training embedding.\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "9e4c2a71", "metadata": {}, "outputs": [], "source": [
"def evaluate_retrieval(df_train, df_val, use_centroids, k=5):\n",
"    \"\"\"Classify validation embeddings by nearest-neighbour search over training embeddings.\n",
"\n",
"    Args:\n",
"        df_train: DataFrame with an 'embeddings' column (one vector per row) and a 'labels' column.\n",
"        df_val: DataFrame with the same columns; every row is used as a query.\n",
"        use_centroids: if True, compare each query against one mean embedding per class\n",
"            instead of against every training embedding.\n",
"        k: neighbourhood size used by the k-NN and top-k metrics.\n",
"\n",
"    Returns:\n",
"        Dict of percentages over the validation rows:\n",
"        'top1' - the nearest neighbour has the true class;\n",
"        'knn' - the true class is the most common among the k nearest (ties go to the closer one);\n",
"        'topk_embs' - the true class appears among the k nearest embeddings;\n",
"        'topk_classes' - the true class is among the k closest distinct classes.\n",
"\n",
"    Raises:\n",
"        ValueError: if df_val has no rows.\n",
"    \"\"\"\n",
"    if len(df_val) == 0:\n",
"        raise ValueError('df_val is empty: no queries to evaluate')\n",
"    if use_centroids:\n",
"        df_train = (\n",
"            df_train.groupby('labels')['embeddings']\n",
"            .apply(lambda embs: np.vstack(embs).mean(axis=0))\n",
"            .reset_index()\n",
"        )\n",
"    x_train = np.vstack(df_train['embeddings'])\n",
"    x_val = np.vstack(df_val['embeddings'])\n",
"    train_labels = df_train['labels'].to_numpy()\n",
"    val_labels = df_val['labels'].to_numpy()\n",
"\n",
"    # Row i holds the training labels ordered from nearest to farthest from query i.\n",
"    dist = distance_matrix(x_val, x_train, p=2)\n",
"    ranked_labels = train_labels[np.argsort(dist, axis=1)]\n",
"\n",
"    counts = Counter()\n",
"    for true_label, ranked in zip(val_labels, ranked_labels):\n",
"        nearest = ranked[:k]\n",
"        counts['top1'] += int(ranked[0] == true_label)\n",
"        # most_common keeps first-seen order on ties, so the closer neighbour wins a tie\n",
"        counts['knn'] += int(Counter(nearest).most_common(1)[0][0] == true_label)\n",
"        counts['topk_embs'] += int(true_label in nearest)\n",
"        # dict.fromkeys de-duplicates while keeping each class at its closest rank\n",
"        counts['topk_classes'] += int(true_label in list(dict.fromkeys(ranked))[:k])\n",
"\n",
"    return {name: 100 * counts[name] / len(val_labels) for name in ('top1', 'knn', 'topk_embs', 'topk_classes')}" ] },
{ "cell_type": "code", "execution_count": 94, "id": "7399b8ae", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using centroids only\n", "Top-1 accuracy: 77.06 %\n", "Top-5 embeddings class match: 100.00 % (Picks any class in the 5 closest embeddings)\n", "\n", "################################\n", "\n", "Using all embeddings\n", "Top-1 accuracy: 81.65 %\n", "5-nn accuracy: 83.49 % (Picks the class that appears most often in the 5 closest embeddings)\n", "Top-5 embeddings class match: 96.33 % (Picks any class in the 5 closest embeddings)\n", "Top-5 unique class match: 99.08 % (Picks the 5 closest distinct classes)\n", "\n", "################################\n", "\n" ] } ], "source": [
"df_train, df_val = dfs['train'], dfs['val']\n",
"for use_centroids, str_use_centroids in [(True, 'Using centroids only'), (False, 'Using all embeddings')]:\n",
"    scores = evaluate_retrieval(df_train, df_val, use_centroids)\n",
"    print(str_use_centroids)\n",
"    print(f'Top-1 accuracy: {scores[\"top1\"]: 0.2f} %')\n",
"    # With one centroid per class, 5-NN reduces to top-1 and the unique-class match to the embedding match,\n",
"    # so those two metrics are only reported against all embeddings.\n",
"    if not use_centroids:\n",
"        print(f'5-nn accuracy: {scores[\"knn\"]: 0.2f} % (Picks the class that appears most often in the 5 closest embeddings)')\n",
"    print(f'Top-5 embeddings class match: {scores[\"topk_embs\"]: 0.2f} % (Picks any class in the 5 closest embeddings)')\n",
"    if not use_centroids:\n",
"        print(f'Top-5 unique class match: {scores[\"topk_classes\"]: 0.2f} % (Picks the 5 closest distinct classes)')\n",
"    print('\\n' + '#'*32 + '\\n')" ] },
{ "cell_type": "markdown", "id": "d2aaac6c", "metadata": {}, "source": [ "## Show some examples (only for WLASL)" ]
}, { "cell_type": "code", "execution_count": 22, "id": "b9d1d309", "metadata": {}, "outputs": [], "source": [ "from IPython.display import Video, display" ] },
{ "cell_type": "code", "execution_count": 23, "id": "fd2a0cd8", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [
"# Show a few training videos for one label. This only works for WLASL: it needs the `video_id` column\n",
"# and the videos under BASE_DATA_FOLDER/wlasl/videos/.\n",
"EXAMPLE_LABEL = 'thursday'\n",
"N_EXAMPLES = 3\n",
"df_examples = dfs['train']\n",
"if 'video_id' not in df_examples.columns:\n",
"    print('No video_id column in the train split: example videos are only available for WLASL.')\n",
"else:\n",
"    for row in df_examples[df_examples.label_name == EXAMPLE_LABEL].head(N_EXAMPLES).itertuples():\n",
"        display(Video(op.join(BASE_DATA_FOLDER, f'wlasl/videos/{row.video_id}.mp4'), embed=True))" ] } ],
"metadata": { "jupytext": { "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } },
"nbformat": 4, "nbformat_minor": 5 }