Created guide and script to export embeddings

This commit is contained in:
2023-04-14 14:40:05 +00:00
parent 49ced1983d
commit 1f24df1b8f
6 changed files with 255 additions and 150 deletions

20
README2.md Normal file
View File

@@ -0,0 +1,20 @@
# Spoter Embeddings
## Creating dataset
First, place all your videos in a single folder. Once that is done, the keypoints can be extracted from the videos using the following command, which stores them in \<path-to-landmarks-folder\>.
```
python3 preprocessing.py extract --videos-folder <path-to-videos-folder> --output-folder <path-to-landmarks-folder>
```
When this is done, the dataset can be created using the following command:
```
python3 preprocessing.py create --landmarks-dataset <path-to-landmarks-folder> --videos-folder <path-to-videos-folder> --dataset-folder <dataset-output-folder> (--create-new-split --test-size <test-percentage>)
```
The above command generates a train (and val) csv file which includes all the extracted keypoints. These can then be used to train a model or to generate embeddings.
## Creating Embeddings
The embeddings can be created using the following command:
```
python3 export_embeddings.py --checkpoint <checkpoint-path> --dataset <path-to-dataset> --output <embeddings-output-file>
```
The command above generates the embeddings for a given dataset and saves them as a csv file.

91
export_embeddings.py Normal file
View File

@@ -0,0 +1,91 @@
import multiprocessing
import os
import torch
import argparse
from datasets.dataset_loader import LocalDatasetLoader
from torch.utils.data import DataLoader
# NOTE: the original file imported SLREmbeddingDataset twice (once from
# datasets.embedding_dataset, once from datasets); the later import shadowed
# the earlier one, so only the package-level import is kept.
from datasets import SLREmbeddingDataset, collate_fn_padd
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
import numpy as np
import random
import pandas as pd

# Seed every RNG source (Python, NumPy, PyTorch CPU/CUDA, hashing) so the
# exported embeddings are reproducible across runs.
seed = 43
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# Dedicated generator handed to the DataLoader below so its worker/shuffle
# state is also deterministic.
generator = torch.Generator()
generator.manual_seed(seed)
def seed_worker(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker process.

    PyTorch assigns each worker a distinct initial seed; deriving the NumPy
    and ``random`` seeds from it keeps per-worker behaviour deterministic.
    (The original version also built an unused local ``torch.Generator``
    seeded from the module-level ``seed``; that dead code is removed.)
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
def parse_args():
    """Build and evaluate the command-line interface for the export script."""
    cli = argparse.ArgumentParser(description='Export embeddings')
    options = (
        ('--checkpoint', 'Path to checkpoint'),
        ('--output', 'Path to output'),
        ('--dataset', 'Path to data'),
    )
    for flag, doc in options:
        cli.add_argument(flag, type=str, default=None, help=doc)
    return cli.parse_args()
args = parse_args()

# Prefer GPU when available; checkpoint tensors are mapped to the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the SPOTER embedding model with the hyper-parameters stored in the
# checkpoint, then restore its trained weights.
checkpoint = torch.load(args.checkpoint, map_location=device)
model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)
model.load_state_dict(checkpoint["state_dict"])
# Bug fix: switch to inference mode so dropout/normalisation layers do not
# perturb the exported embeddings.
model.eval()

dataset = SLREmbeddingDataset(args.dataset, triplet=False, augmentations=False)
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,  # keep output order aligned with the dataset CSV rows
    collate_fn=collate_fn_padd,
    pin_memory=torch.cuda.is_available(),
    num_workers=multiprocessing.cpu_count(),
    worker_init_fn=seed_worker,
    generator=generator,
)

# Run every sample through the model and collect one embedding per row.
embeddings = []
with torch.no_grad():
    for inputs, labels, masks in data_loader:
        inputs = inputs.to(device)
        masks = masks.to(device)
        outputs = model(inputs, masks)
        for row in outputs.cpu().numpy():
            embeddings.append(row)

# Join the embeddings back onto the dataset rows and write the CSV. The
# explicit .copy() avoids pandas' SettingWithCopyWarning when adding the
# list-serialised 'embeddings2' column to the column-sliced frame.
df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']].copy()
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist())
df.to_csv(args.output, index=False)

View File

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"id": "c20f7fd5",
"metadata": {},
"outputs": [],
@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"id": "ada032d0",
"metadata": {},
"outputs": [],
@@ -22,13 +22,12 @@
"import os\n",
"import os.path as op\n",
"import pandas as pd\n",
"import json\n",
"import base64"
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 7,
"id": "05682e73",
"metadata": {},
"outputs": [],
@@ -38,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"id": "fede7684",
"metadata": {},
"outputs": [],
@@ -48,7 +47,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"id": "ce531994",
"metadata": {},
"outputs": [],
@@ -64,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"id": "f4a2d672",
"metadata": {},
"outputs": [],
@@ -87,17 +86,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"id": "1d9db764",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fb050be7710>"
"<torch._C.Generator at 0x7f010919d710>"
]
},
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -119,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"id": "71224139",
"metadata": {},
"outputs": [],
@@ -133,7 +132,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 13,
"id": "013d3774",
"metadata": {},
"outputs": [
@@ -143,7 +142,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 9,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -169,27 +168,28 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 24,
"id": "ba6b58f0",
"metadata": {},
"outputs": [],
"source": [
"SL_DATASET = 'wlasl' # or 'lsa'\n",
"if SL_DATASET == 'wlasl':\n",
" dataset_name = \"fingerspelling\"\n",
" num_classes = 100\n",
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
"else:\n",
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
" num_classes = 64\n",
" split_dataset_path = \"LSA64_{}.csv\"\n",
"SL_DATASET = 'basic-signs' # or 'wlasl'\n",
"\n",
"if SL_DATASET == 'fingerspelling':\n",
" dataset_name = \"fingerspelling\"\n",
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
"elif SL_DATASET == 'wlasl':\n",
" dataset_name = \"wlasl\"\n",
" split_dataset_path = \"WLASL100_{}.csv\"\n",
"elif SL_DATASET == 'basic-signs':\n",
" dataset_name = \"basic-signs\"\n",
" split_dataset_path = \"basic-signs_{}.csv\"\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 25,
"id": "5643a72c",
"metadata": {},
"outputs": [],
@@ -209,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 16,
"id": "04a62088",
"metadata": {},
"outputs": [],
@@ -222,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 17,
"id": "79c837c1",
"metadata": {},
"outputs": [],
@@ -253,7 +253,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 18,
"id": "8b5bda73",
"metadata": {},
"outputs": [],
@@ -280,17 +280,17 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 19,
"id": "0efa0871",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(560, 560)"
"(164, 164)"
]
},
"execution_count": 88,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -301,7 +301,21 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 21,
"id": "ab83c6e2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"for split in splits:\n",
" df = dfs[split]\n",
" df['embeddings'] = embeddings_split[split]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0b9fb9c2",
"metadata": {},
"outputs": [
@@ -309,54 +323,42 @@
"name": "stdout",
"output_type": "stream",
"text": [
"0 [0.4734516, -0.58630264, 0.18397862, -0.165259...\n",
"1 [1.6672437, -2.3754091, -0.77506787, -0.666019...\n",
"2 [1.7801772, -0.0077665895, 0.22098881, 0.09736...\n",
"3 [-0.6503094, 0.14683367, 0.1253598, 0.5183654,...\n",
"4 [1.2275296, -0.4874984, 0.56826925, -0.9628880...\n",
"0 [1.7327625, -3.015248, -1.4775522, -0.7505071,...\n",
"1 [2.0936582, -0.596195, -0.7918601, -0.15896143...\n",
"2 [-1.4007742, -0.9608915, 1.3294879, -0.5185398...\n",
"3 [1.3280737, -3.299126, -1.0110444, -1.2528414,...\n",
"4 [-0.071124956, -0.79259753, 0.7182858, 0.38130...\n",
" ... \n",
"555 [-0.4408903, -0.9623146, 0.21583065, -0.381131...\n",
"556 [1.7910445, -3.5434258, -1.332628, -0.95276725...\n",
"557 [2.3283613, 0.11504881, -0.4955331, -0.4563401...\n",
"558 [-1.0491562, -1.1793315, 0.3248821, 0.16679825...\n",
"559 [1.447621, -1.2482919, 0.17936605, -1.4752473,...\n",
"Name: embeddings, Length: 560, dtype: object\n",
"0 B\n",
"1 D\n",
"2 X\n",
"3 O\n",
"4 W\n",
" ..\n",
"555 F\n",
"556 X\n",
"557 Z\n",
"558 Y\n",
"559 W\n",
"Name: label_name, Length: 560, dtype: object\n",
"159 [-1.5968355, 1.9617733, 0.28859574, 1.256657, ...\n",
"160 [0.44801116, -1.8377966, 1.1004394, -1.195648,...\n",
"161 [2.0584257, 1.6986116, 0.5129896, 0.27279535, ...\n",
"162 [1.6695516, -2.967027, -1.5715427, -0.77170163...\n",
"163 [1.4977738, -2.6278958, -1.6123883, -0.8420623...\n",
"Name: embeddings, Length: 164, dtype: object\n",
"0 TOT-ZIENS\n",
"1 GOED\n",
"2 GOEDENACHT\n",
"3 NEE\n",
"4 SLECHT\n",
" ... \n",
"159 SORRY\n",
"160 GOEDEMORGEN\n",
"161 LINKS\n",
"162 TOT-ZIENS\n",
"163 GOED\n",
"Name: label_name, Length: 164, dtype: object\n",
"0 0\n",
"1 1\n",
"2 2\n",
"3 3\n",
"4 5\n",
"4 4\n",
" ..\n",
"555 24\n",
"556 2\n",
"557 14\n",
"558 8\n",
"559 5\n",
"Name: labels, Length: 560, dtype: int64\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_969762/1944871806.py:9: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n"
"159 7\n",
"160 5\n",
"161 13\n",
"162 0\n",
"163 1\n",
"Name: labels, Length: 164, dtype: int64\n"
]
}
],
@@ -372,21 +374,7 @@
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
"\n",
"# save the dfs['train']\n",
"dfs['train'].to_csv('../data/fingerspelling/embeddings.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "ab83c6e2",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [],
"source": [
"for split in splits:\n",
" df = dfs[split]\n",
" df['embeddings'] = embeddings_split[split]"
"dfs['train'].to_csv(f'../data/{dataset_name}/embeddings.csv', index=False)"
]
},
{
@@ -400,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 23,
"id": "7399b8ae",
"metadata": {},
"outputs": [
@@ -409,16 +397,16 @@
"output_type": "stream",
"text": [
"Using centroids only\n",
"Top-1 accuracy: 77.06 %\n",
"Top-5 embeddings class match: 100.00 % (Picks any class in the 5 closest embeddings)\n",
"Top-1 accuracy: 80.00 %\n",
"Top-5 embeddings class match: 93.33 % (Picks any class in the 5 closest embeddings)\n",
"\n",
"################################\n",
"\n",
"Using all embeddings\n",
"Top-1 accuracy: 81.65 %\n",
"5-nn accuracy: 83.49 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 96.33 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 99.08 % (Picks the 5 closest distinct classes)\n",
"Top-1 accuracy: 80.00 %\n",
"5-nn accuracy: 80.00 % (Picks the class that appears most often in the 5 closest embeddings)\n",
"Top-5 embeddings class match: 86.67 % (Picks any class in the 5 closest embeddings)\n",
"Top-5 unique class match: 93.33 % (Picks the 5 closest distinct classes)\n",
"\n",
"################################\n",
"\n"

File diff suppressed because one or more lines are too long

View File

@@ -152,19 +152,23 @@ def create(args):
df = pd.concat([df, df_aux], axis=1)
if args.create_new_split:
df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)
else:
print(df['split'].unique())
df_train = df[(df['split'] == 'train') | (df['split'] == 'val')]
df_test = df[df['split'] == 'test']
print(f'Num classes: {num_classes}')
print(df_train['labels'].value_counts())
print(df_test['labels'].value_counts())
assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
)), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
the --create-new-split flag'
for split, df_split in zip(['train', 'val'],
[df_train, df_test]):
fn_out = op.join(dataset_folder, f'fingerspelling_{split}.csv')
fn_out = op.join(dataset_folder, f'{split}.csv')
(df_split.reset_index(drop=True)
.applymap(convert_to_str)
.to_csv(fn_out, index=False))
else:
fn_out = op.join(dataset_folder, 'train.csv')
(df.reset_index(drop=True)
.applymap(convert_to_str)
.to_csv(fn_out, index=False))

View File

@@ -1,14 +1,14 @@
#!/bin/sh
python -m train \
--save_checkpoints_every 10 \
--experiment_name "basic" \
--experiment_name "wlasl" \
--epochs 300 \
--optimizer "ADAM" \
--lr 0.0001 \
--batch_size 16 \
--dataset_name "GoogleWLASL" \
--training_set_path "spoter_train.csv" \
--validation_set_path "spoter_test.csv" \
--dataset_name "WLASL" \
--training_set_path "WLASL100_train.csv" \
--validation_set_path "WLASL100_val.csv" \
--vector_length 32 \
--epoch_iters -1 \
--scheduler_factor 0.2 \