Created guide and script to export embeddings
This commit is contained in:
20
README2.md
Normal file
20
README2.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Spoter Embeddings
|
||||
|
||||
## Creating dataset
|
||||
First, make a folder where all your videos are located. When this is done, all keypoints can be extracted from the videos using the following command. This will extract the keypoints and store them in \<path-to-landmarks-folder\>.
|
||||
```
|
||||
python3 preprocessing.py extract --videos-folder <path-to-videos-folder> --output-folder <path-to-landmarks-folder>
|
||||
```
|
||||
|
||||
When this is done, the dataset can be created using the following command:
|
||||
```
|
||||
python3 preprocessing.py create --landmarks-dataset <path-to-landmarks-folder> --videos-folder <path-to-videos-folder> --dataset-folder <dataset-output-folder> (--create-new-split --test-size <test-percentage>)
|
||||
```
|
||||
The above command generates a train (and val) csv file which includes all the extracted keypoints. These can then be used to train the model or to generate embeddings.
|
||||
|
||||
## Creating Embeddings
|
||||
The embeddings can be created using the following command:
|
||||
```
|
||||
python3 export_embeddings.py --checkpoint <checkpoint-path> --dataset <path-to-dataset> --output <embeddings-output-file>
|
||||
```
|
||||
The command above generates the embeddings for a given dataset and saves them as a csv file.
|
||||
91
export_embeddings.py
Normal file
91
export_embeddings.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
from datasets.dataset_loader import LocalDatasetLoader
|
||||
from datasets.embedding_dataset import SLREmbeddingDataset
|
||||
from torch.utils.data import DataLoader
|
||||
from datasets import SLREmbeddingDataset, collate_fn_padd
|
||||
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
|
||||
import numpy as np
|
||||
import random
|
||||
import pandas as pd
|
||||
|
||||
# Reproducibility: pin every RNG this export pipeline touches to one seed,
# so repeated runs over the same checkpoint/dataset emit identical embeddings.
seed = 43

# PyTorch RNGs first (CPU and all CUDA devices), plus deterministic kernels.
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)

# Python-level RNGs and hash randomization.
random.seed(seed)
np.random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

# Dedicated generator handed to the DataLoader for worker shuffling/seeding.
generator = torch.Generator()
generator.manual_seed(seed)
|
||||
|
||||
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: derive per-worker NumPy/random seeds.

    PyTorch hands each worker process a distinct initial seed; fold it into
    the other RNGs so any worker-side randomness is reproducible.
    """
    derived = torch.initial_seed() % 2**32
    np.random.seed(derived)
    random.seed(derived)
|
||||
|
||||
# NOTE(review): this re-creates and re-seeds `generator`, shadowing the
# identically-seeded one built in the seeding section above — one of the
# two is redundant and could be removed.
generator = torch.Generator()
generator.manual_seed(seed)
|
||||
|
||||
def parse_args():
    """Parse command-line arguments for the embedding export script.

    Returns:
        argparse.Namespace with `checkpoint`, `output` and `dataset` paths.
    """
    parser = argparse.ArgumentParser(description='Export embeddings')
    # All three paths are dereferenced unconditionally later in the script
    # (torch.load, pd.read_csv, df.to_csv), so require them up front instead
    # of defaulting to None and crashing with an opaque TypeError.
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to checkpoint')
    parser.add_argument('--output', type=str, required=True, help='Path to output')
    parser.add_argument('--dataset', type=str, required=True, help='Path to data')
    args = parser.parse_args()
    return args
|
||||
|
||||
args = parse_args()

# Prefer the GPU when one is available; the checkpoint is mapped straight
# onto the device the embeddings will be computed on.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained checkpoint (weights plus the config used to build it).
checkpoint = torch.load(args.checkpoint, map_location=device)
|
||||
|
||||
# Rebuild the model with the exact hyper-parameters stored in the checkpoint,
# then restore its trained weights.
model = SPOTER_EMBEDDINGS(
    features=checkpoint["config_args"].vector_length,
    hidden_dim=checkpoint["config_args"].hidden_dim,
    norm_emb=checkpoint["config_args"].normalize_embeddings,
).to(device)

model.load_state_dict(checkpoint["state_dict"])

# NOTE(review): `dataset_loader` is never used below — presumably leftover;
# confirm LocalDatasetLoader() has no required side effects before removing.
dataset_loader = LocalDatasetLoader()
# Plain (non-triplet) dataset with no augmentations: one embedding per sample.
dataset = SLREmbeddingDataset(args.dataset, triplet=False, augmentations=False)
|
||||
|
||||
# Deterministic, order-preserving loader: batch_size=1 and shuffle=False so
# the i-th output embedding lines up with the i-th row of the dataset csv.
data_loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn_padd,  # presumably pads variable-length sequences — confirm
    pin_memory=torch.cuda.is_available(),
    num_workers=multiprocessing.cpu_count(),
    worker_init_fn=seed_worker,  # reseed NumPy/random inside each worker
    generator=generator,
)
|
||||
|
||||
# Run inference and collect one embedding vector per sample, in dataset order.
# (The original also kept a dead batch counter `k` and unused loop index —
# both removed here.)
embeddings = []
model.eval()  # disable dropout / batch-norm training behavior for inference
with torch.no_grad():
    for inputs, _labels, masks in data_loader:
        inputs = inputs.to(device)
        masks = masks.to(device)
        outputs = model(inputs, masks)

        # batch_size is 1 above, but iterate over the batch dim defensively.
        for n in range(outputs.shape[0]):
            embeddings.append(outputs[n].cpu().numpy())
|
||||
|
||||
# Join the embeddings back onto the source csv; row order matches because
# the DataLoader was built with shuffle=False.
df = pd.read_csv(args.dataset)
df["embeddings"] = embeddings
df = df[['embeddings', 'label_name', 'labels']]
# NOTE(review): `embeddings2` holds the same data as `embeddings`, just
# converted to a Python list, and both columns are written to the output
# csv — confirm downstream consumers need both before dropping either.
df['embeddings2'] = df['embeddings'].apply(lambda x: x.tolist())


df.to_csv(args.output, index=False)
|
||||
@@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 5,
|
||||
"id": "c20f7fd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -13,7 +13,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 6,
|
||||
"id": "ada032d0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -22,13 +22,12 @@
|
||||
"import os\n",
|
||||
"import os.path as op\n",
|
||||
"import pandas as pd\n",
|
||||
"import json\n",
|
||||
"import base64"
|
||||
"import json"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 7,
|
||||
"id": "05682e73",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -38,7 +37,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 8,
|
||||
"id": "fede7684",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -48,7 +47,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 9,
|
||||
"id": "ce531994",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -64,7 +63,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 10,
|
||||
"id": "f4a2d672",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -87,17 +86,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 11,
|
||||
"id": "1d9db764",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<torch._C.Generator at 0x7fb050be7710>"
|
||||
"<torch._C.Generator at 0x7f010919d710>"
|
||||
]
|
||||
},
|
||||
"execution_count": 7,
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -119,7 +118,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 12,
|
||||
"id": "71224139",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -133,7 +132,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 13,
|
||||
"id": "013d3774",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -143,7 +142,7 @@
|
||||
"<All keys matched successfully>"
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -169,27 +168,28 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 75,
|
||||
"execution_count": 24,
|
||||
"id": "ba6b58f0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"SL_DATASET = 'wlasl' # or 'lsa'\n",
|
||||
"if SL_DATASET == 'wlasl':\n",
|
||||
"SL_DATASET = 'basic-signs' # or 'wlasl'\n",
|
||||
"\n",
|
||||
"if SL_DATASET == 'fingerspelling':\n",
|
||||
" dataset_name = \"fingerspelling\"\n",
|
||||
" num_classes = 100\n",
|
||||
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
|
||||
"else:\n",
|
||||
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
|
||||
" num_classes = 64\n",
|
||||
" split_dataset_path = \"LSA64_{}.csv\"\n",
|
||||
" \n",
|
||||
"elif SL_DATASET == 'wlasl':\n",
|
||||
" dataset_name = \"wlasl\"\n",
|
||||
" split_dataset_path = \"WLASL100_{}.csv\"\n",
|
||||
"elif SL_DATASET == 'basic-signs':\n",
|
||||
" dataset_name = \"basic-signs\"\n",
|
||||
" split_dataset_path = \"basic-signs_{}.csv\"\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 76,
|
||||
"execution_count": 25,
|
||||
"id": "5643a72c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -209,7 +209,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 77,
|
||||
"execution_count": 16,
|
||||
"id": "04a62088",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -222,7 +222,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 86,
|
||||
"execution_count": 17,
|
||||
"id": "79c837c1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -253,7 +253,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 87,
|
||||
"execution_count": 18,
|
||||
"id": "8b5bda73",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -280,17 +280,17 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 88,
|
||||
"execution_count": 19,
|
||||
"id": "0efa0871",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(560, 560)"
|
||||
"(164, 164)"
|
||||
]
|
||||
},
|
||||
"execution_count": 88,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -301,7 +301,21 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 91,
|
||||
"execution_count": 21,
|
||||
"id": "ab83c6e2",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for split in splits:\n",
|
||||
" df = dfs[split]\n",
|
||||
" df['embeddings'] = embeddings_split[split]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"id": "0b9fb9c2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -309,54 +323,42 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0 [0.4734516, -0.58630264, 0.18397862, -0.165259...\n",
|
||||
"1 [1.6672437, -2.3754091, -0.77506787, -0.666019...\n",
|
||||
"2 [1.7801772, -0.0077665895, 0.22098881, 0.09736...\n",
|
||||
"3 [-0.6503094, 0.14683367, 0.1253598, 0.5183654,...\n",
|
||||
"4 [1.2275296, -0.4874984, 0.56826925, -0.9628880...\n",
|
||||
"0 [1.7327625, -3.015248, -1.4775522, -0.7505071,...\n",
|
||||
"1 [2.0936582, -0.596195, -0.7918601, -0.15896143...\n",
|
||||
"2 [-1.4007742, -0.9608915, 1.3294879, -0.5185398...\n",
|
||||
"3 [1.3280737, -3.299126, -1.0110444, -1.2528414,...\n",
|
||||
"4 [-0.071124956, -0.79259753, 0.7182858, 0.38130...\n",
|
||||
" ... \n",
|
||||
"555 [-0.4408903, -0.9623146, 0.21583065, -0.381131...\n",
|
||||
"556 [1.7910445, -3.5434258, -1.332628, -0.95276725...\n",
|
||||
"557 [2.3283613, 0.11504881, -0.4955331, -0.4563401...\n",
|
||||
"558 [-1.0491562, -1.1793315, 0.3248821, 0.16679825...\n",
|
||||
"559 [1.447621, -1.2482919, 0.17936605, -1.4752473,...\n",
|
||||
"Name: embeddings, Length: 560, dtype: object\n",
|
||||
"0 B\n",
|
||||
"1 D\n",
|
||||
"2 X\n",
|
||||
"3 O\n",
|
||||
"4 W\n",
|
||||
" ..\n",
|
||||
"555 F\n",
|
||||
"556 X\n",
|
||||
"557 Z\n",
|
||||
"558 Y\n",
|
||||
"559 W\n",
|
||||
"Name: label_name, Length: 560, dtype: object\n",
|
||||
"159 [-1.5968355, 1.9617733, 0.28859574, 1.256657, ...\n",
|
||||
"160 [0.44801116, -1.8377966, 1.1004394, -1.195648,...\n",
|
||||
"161 [2.0584257, 1.6986116, 0.5129896, 0.27279535, ...\n",
|
||||
"162 [1.6695516, -2.967027, -1.5715427, -0.77170163...\n",
|
||||
"163 [1.4977738, -2.6278958, -1.6123883, -0.8420623...\n",
|
||||
"Name: embeddings, Length: 164, dtype: object\n",
|
||||
"0 TOT-ZIENS\n",
|
||||
"1 GOED\n",
|
||||
"2 GOEDENACHT\n",
|
||||
"3 NEE\n",
|
||||
"4 SLECHT\n",
|
||||
" ... \n",
|
||||
"159 SORRY\n",
|
||||
"160 GOEDEMORGEN\n",
|
||||
"161 LINKS\n",
|
||||
"162 TOT-ZIENS\n",
|
||||
"163 GOED\n",
|
||||
"Name: label_name, Length: 164, dtype: object\n",
|
||||
"0 0\n",
|
||||
"1 1\n",
|
||||
"2 2\n",
|
||||
"3 3\n",
|
||||
"4 5\n",
|
||||
"4 4\n",
|
||||
" ..\n",
|
||||
"555 24\n",
|
||||
"556 2\n",
|
||||
"557 14\n",
|
||||
"558 8\n",
|
||||
"559 5\n",
|
||||
"Name: labels, Length: 560, dtype: int64\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/tmp/ipykernel_969762/1944871806.py:9: SettingWithCopyWarning: \n",
|
||||
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
||||
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
||||
"\n",
|
||||
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
||||
" dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n"
|
||||
"159 7\n",
|
||||
"160 5\n",
|
||||
"161 13\n",
|
||||
"162 0\n",
|
||||
"163 1\n",
|
||||
"Name: labels, Length: 164, dtype: int64\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -372,21 +374,7 @@
|
||||
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
|
||||
"\n",
|
||||
"# save the dfs['train']\n",
|
||||
"dfs['train'].to_csv('../data/fingerspelling/embeddings.csv', index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 90,
|
||||
"id": "ab83c6e2",
|
||||
"metadata": {
|
||||
"lines_to_next_cell": 2
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for split in splits:\n",
|
||||
" df = dfs[split]\n",
|
||||
" df['embeddings'] = embeddings_split[split]"
|
||||
"dfs['train'].to_csv(f'../data/{dataset_name}/embeddings.csv', index=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -400,7 +388,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 94,
|
||||
"execution_count": 23,
|
||||
"id": "7399b8ae",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -409,16 +397,16 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using centroids only\n",
|
||||
"Top-1 accuracy: 77.06 %\n",
|
||||
"Top-5 embeddings class match: 100.00 % (Picks any class in the 5 closest embeddings)\n",
|
||||
"Top-1 accuracy: 80.00 %\n",
|
||||
"Top-5 embeddings class match: 93.33 % (Picks any class in the 5 closest embeddings)\n",
|
||||
"\n",
|
||||
"################################\n",
|
||||
"\n",
|
||||
"Using all embeddings\n",
|
||||
"Top-1 accuracy: 81.65 %\n",
|
||||
"5-nn accuracy: 83.49 % (Picks the class that appears most often in the 5 closest embeddings)\n",
|
||||
"Top-5 embeddings class match: 96.33 % (Picks any class in the 5 closest embeddings)\n",
|
||||
"Top-5 unique class match: 99.08 % (Picks the 5 closest distinct classes)\n",
|
||||
"Top-1 accuracy: 80.00 %\n",
|
||||
"5-nn accuracy: 80.00 % (Picks the class that appears most often in the 5 closest embeddings)\n",
|
||||
"Top-5 embeddings class match: 86.67 % (Picks any class in the 5 closest embeddings)\n",
|
||||
"Top-5 unique class match: 93.33 % (Picks the 5 closest distinct classes)\n",
|
||||
"\n",
|
||||
"################################\n",
|
||||
"\n"
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -152,19 +152,23 @@ def create(args):
|
||||
df = pd.concat([df, df_aux], axis=1)
|
||||
if args.create_new_split:
|
||||
df_train, df_test = train_test_split(df, test_size=test_size, stratify=df['labels'], random_state=42)
|
||||
else:
|
||||
print(df['split'].unique())
|
||||
df_train = df[(df['split'] == 'train') | (df['split'] == 'val')]
|
||||
df_test = df[df['split'] == 'test']
|
||||
|
||||
|
||||
print(f'Num classes: {num_classes}')
|
||||
print(df_train['labels'].value_counts())
|
||||
print(df_test['labels'].value_counts())
|
||||
assert set(df_train['labels'].unique()) == set(df_test['labels'].unique(
|
||||
)), 'The labels for train and test dataframe are different. We recommend to download the dataset again, or to use \
|
||||
the --create-new-split flag'
|
||||
for split, df_split in zip(['train', 'val'],
|
||||
[df_train, df_test]):
|
||||
fn_out = op.join(dataset_folder, f'fingerspelling_{split}.csv')
|
||||
fn_out = op.join(dataset_folder, f'{split}.csv')
|
||||
(df_split.reset_index(drop=True)
|
||||
.applymap(convert_to_str)
|
||||
.to_csv(fn_out, index=False))
|
||||
|
||||
else:
|
||||
fn_out = op.join(dataset_folder, 'train.csv')
|
||||
(df.reset_index(drop=True)
|
||||
.applymap(convert_to_str)
|
||||
.to_csv(fn_out, index=False))
|
||||
8
train.sh
8
train.sh
@@ -1,14 +1,14 @@
|
||||
#!/bin/sh
|
||||
python -m train \
|
||||
--save_checkpoints_every 10 \
|
||||
--experiment_name "basic" \
|
||||
--experiment_name "wlasl" \
|
||||
--epochs 300 \
|
||||
--optimizer "ADAM" \
|
||||
--lr 0.0001 \
|
||||
--batch_size 16 \
|
||||
--dataset_name "GoogleWLASL" \
|
||||
--training_set_path "spoter_train.csv" \
|
||||
--validation_set_path "spoter_test.csv" \
|
||||
--dataset_name "WLASL" \
|
||||
--training_set_path "WLASL100_train.csv" \
|
||||
--validation_set_path "WLASL100_val.csv" \
|
||||
--vector_length 32 \
|
||||
--epoch_iters -1 \
|
||||
--scheduler_factor 0.2 \
|
||||
|
||||
Reference in New Issue
Block a user