Updated some files for alphabet visualization
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -155,3 +155,4 @@ out-img/
|
|||||||
converted_models/
|
converted_models/
|
||||||
*.pth
|
*.pth
|
||||||
*.onnx
|
*.onnx
|
||||||
|
.devcontainer
|
||||||
|
|||||||
@@ -94,7 +94,7 @@
|
|||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"<torch._C.Generator at 0x7f29f89e3ed0>"
|
"<torch._C.Generator at 0x7fb050be7710>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 7,
|
"execution_count": 7,
|
||||||
@@ -133,7 +133,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 9,
|
||||||
"id": "013d3774",
|
"id": "013d3774",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -143,7 +143,7 @@
|
|||||||
"<All keys matched successfully>"
|
"<All keys matched successfully>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 13,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -155,7 +155,7 @@
|
|||||||
"# checkpoint = torch.load(model.get_weights())\n",
|
"# checkpoint = torch.load(model.get_weights())\n",
|
||||||
"\n",
|
"\n",
|
||||||
"## Set your path to checkoint here\n",
|
"## Set your path to checkoint here\n",
|
||||||
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_6.pth\"\n",
|
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth\"\n",
|
||||||
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
|
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"model = SPOTER_EMBEDDINGS(\n",
|
"model = SPOTER_EMBEDDINGS(\n",
|
||||||
@@ -169,16 +169,16 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 75,
|
||||||
"id": "ba6b58f0",
|
"id": "ba6b58f0",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"SL_DATASET = 'wlasl' # or 'lsa'\n",
|
"SL_DATASET = 'wlasl' # or 'lsa'\n",
|
||||||
"if SL_DATASET == 'wlasl':\n",
|
"if SL_DATASET == 'wlasl':\n",
|
||||||
" dataset_name = \"wlasl\"\n",
|
" dataset_name = \"fingerspelling\"\n",
|
||||||
" num_classes = 100\n",
|
" num_classes = 100\n",
|
||||||
" split_dataset_path = \"WLASL100_train.csv\"\n",
|
" split_dataset_path = \"fingerspelling_{}.csv\"\n",
|
||||||
"else:\n",
|
"else:\n",
|
||||||
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
|
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
|
||||||
" num_classes = 64\n",
|
" num_classes = 64\n",
|
||||||
@@ -189,7 +189,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 76,
|
||||||
"id": "5643a72c",
|
"id": "5643a72c",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -209,7 +209,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 77,
|
||||||
"id": "04a62088",
|
"id": "04a62088",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -222,13 +222,13 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 86,
|
||||||
"id": "79c837c1",
|
"id": "79c837c1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"dataloaders = {}\n",
|
"dataloaders = {}\n",
|
||||||
"splits = ['train', 'val']\n",
|
"splits = ['train', 'val']\n",
|
||||||
"dfs = {}\n",
|
"dfs = {}\n",
|
||||||
"for split in splits:\n",
|
"for split in splits:\n",
|
||||||
" split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
|
" split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
|
||||||
@@ -253,7 +253,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 87,
|
||||||
"id": "8b5bda73",
|
"id": "8b5bda73",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -269,6 +269,8 @@
|
|||||||
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
|
" for i, (inputs, labels, masks) in enumerate(dataloader):\n",
|
||||||
" k += 1\n",
|
" k += 1\n",
|
||||||
" inputs = inputs.to(device)\n",
|
" inputs = inputs.to(device)\n",
|
||||||
|
" \n",
|
||||||
|
"\n",
|
||||||
" masks = masks.to(device)\n",
|
" masks = masks.to(device)\n",
|
||||||
" outputs = model(inputs, masks)\n",
|
" outputs = model(inputs, masks)\n",
|
||||||
" for n in range(outputs.shape[0]):\n",
|
" for n in range(outputs.shape[0]):\n",
|
||||||
@@ -278,17 +280,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 88,
|
||||||
"id": "0efa0871",
|
"id": "0efa0871",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"text/plain": [
|
"text/plain": [
|
||||||
"(810, 810)"
|
"(560, 560)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 19,
|
"execution_count": 88,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -299,7 +301,83 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 91,
|
||||||
|
"id": "0b9fb9c2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"0 [0.4734516, -0.58630264, 0.18397862, -0.165259...\n",
|
||||||
|
"1 [1.6672437, -2.3754091, -0.77506787, -0.666019...\n",
|
||||||
|
"2 [1.7801772, -0.0077665895, 0.22098881, 0.09736...\n",
|
||||||
|
"3 [-0.6503094, 0.14683367, 0.1253598, 0.5183654,...\n",
|
||||||
|
"4 [1.2275296, -0.4874984, 0.56826925, -0.9628880...\n",
|
||||||
|
" ... \n",
|
||||||
|
"555 [-0.4408903, -0.9623146, 0.21583065, -0.381131...\n",
|
||||||
|
"556 [1.7910445, -3.5434258, -1.332628, -0.95276725...\n",
|
||||||
|
"557 [2.3283613, 0.11504881, -0.4955331, -0.4563401...\n",
|
||||||
|
"558 [-1.0491562, -1.1793315, 0.3248821, 0.16679825...\n",
|
||||||
|
"559 [1.447621, -1.2482919, 0.17936605, -1.4752473,...\n",
|
||||||
|
"Name: embeddings, Length: 560, dtype: object\n",
|
||||||
|
"0 B\n",
|
||||||
|
"1 D\n",
|
||||||
|
"2 X\n",
|
||||||
|
"3 O\n",
|
||||||
|
"4 W\n",
|
||||||
|
" ..\n",
|
||||||
|
"555 F\n",
|
||||||
|
"556 X\n",
|
||||||
|
"557 Z\n",
|
||||||
|
"558 Y\n",
|
||||||
|
"559 W\n",
|
||||||
|
"Name: label_name, Length: 560, dtype: object\n",
|
||||||
|
"0 0\n",
|
||||||
|
"1 1\n",
|
||||||
|
"2 2\n",
|
||||||
|
"3 3\n",
|
||||||
|
"4 5\n",
|
||||||
|
" ..\n",
|
||||||
|
"555 24\n",
|
||||||
|
"556 2\n",
|
||||||
|
"557 14\n",
|
||||||
|
"558 8\n",
|
||||||
|
"559 5\n",
|
||||||
|
"Name: labels, Length: 560, dtype: int64\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/tmp/ipykernel_969762/1944871806.py:9: SettingWithCopyWarning: \n",
|
||||||
|
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
|
||||||
|
"Try using .loc[row_indexer,col_indexer] = value instead\n",
|
||||||
|
"\n",
|
||||||
|
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
|
||||||
|
" dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(dfs['train'][\"embeddings\"])\n",
|
||||||
|
"print(dfs['train'][\"label_name\"])\n",
|
||||||
|
"print(dfs['train'][\"labels\"])\n",
|
||||||
|
"\n",
|
||||||
|
"# only keep these columns\n",
|
||||||
|
"dfs['train'] = dfs['train'][['embeddings', 'label_name', 'labels']]\n",
|
||||||
|
"\n",
|
||||||
|
"# convert embeddings to string\n",
|
||||||
|
"dfs['train']['embeddings2'] = dfs['train']['embeddings'].apply(lambda x: x.tolist())\n",
|
||||||
|
"\n",
|
||||||
|
"# save the dfs['train']\n",
|
||||||
|
"dfs['train'].to_csv('../data/fingerspelling/embeddings.csv', index=False)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 90,
|
||||||
"id": "ab83c6e2",
|
"id": "ab83c6e2",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"lines_to_next_cell": 2
|
"lines_to_next_cell": 2
|
||||||
@@ -322,7 +400,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 94,
|
||||||
"id": "7399b8ae",
|
"id": "7399b8ae",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -331,16 +409,16 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Using centroids only\n",
|
"Using centroids only\n",
|
||||||
"Top-1 accuracy: 5.19 %\n",
|
"Top-1 accuracy: 77.06 %\n",
|
||||||
"Top-5 embeddings class match: 17.65 % (Picks any class in the 5 closest embeddings)\n",
|
"Top-5 embeddings class match: 100.00 % (Picks any class in the 5 closest embeddings)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"################################\n",
|
"################################\n",
|
||||||
"\n",
|
"\n",
|
||||||
"Using all embeddings\n",
|
"Using all embeddings\n",
|
||||||
"Top-1 accuracy: 5.31 %\n",
|
"Top-1 accuracy: 81.65 %\n",
|
||||||
"5-nn accuracy: 5.56 % (Picks the class that appears most often in the 5 closest embeddings)\n",
|
"5-nn accuracy: 83.49 % (Picks the class that appears most often in the 5 closest embeddings)\n",
|
||||||
"Top-5 embeddings class match: 15.43 % (Picks any class in the 5 closest embeddings)\n",
|
"Top-5 embeddings class match: 96.33 % (Picks any class in the 5 closest embeddings)\n",
|
||||||
"Top-5 unique class match: 15.56 % (Picks the 5 closest distinct classes)\n",
|
"Top-5 unique class match: 99.08 % (Picks the 5 closest distinct classes)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"################################\n",
|
"################################\n",
|
||||||
"\n"
|
"\n"
|
||||||
@@ -375,13 +453,13 @@
|
|||||||
" sorted_labels = labels[argsort]\n",
|
" sorted_labels = labels[argsort]\n",
|
||||||
" if sorted_labels[0] == true_label:\n",
|
" if sorted_labels[0] == true_label:\n",
|
||||||
" top1 += 1\n",
|
" top1 += 1\n",
|
||||||
" if use_centroids:\n",
|
" # if use_centroids:\n",
|
||||||
" good_samples.append(df_val.loc[i, 'video_id'])\n",
|
" # good_samples.append(df_val.loc[i, 'video_id'])\n",
|
||||||
" else:\n",
|
" # else:\n",
|
||||||
" good_samples.append((df_val.loc[i, 'video_id'],\n",
|
" # good_samples.append((df_val.loc[i, 'video_id'],\n",
|
||||||
" df_train.loc[argsort[0], 'video_id'],\n",
|
" # df_train.loc[argsort[0], 'video_id'],\n",
|
||||||
" i,\n",
|
" # i,\n",
|
||||||
" argsort[0]))\n",
|
" # argsort[0]))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
|
" if true_label == Counter(sorted_labels[:5]).most_common()[0][0]:\n",
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -143,9 +143,7 @@ def create(args):
|
|||||||
lmks_data.append(lmks_dict)
|
lmks_data.append(lmks_dict)
|
||||||
|
|
||||||
df_lmks = pd.DataFrame(lmks_data)
|
df_lmks = pd.DataFrame(lmks_data)
|
||||||
print(df_lmks)
|
|
||||||
df = pd.merge(df_video, df_lmks)
|
df = pd.merge(df_video, df_lmks)
|
||||||
print(df)
|
|
||||||
aux_columns = ['split', 'video_id', 'labels', 'label_name']
|
aux_columns = ['split', 'video_id', 'labels', 'label_name']
|
||||||
if videos_folder is not None:
|
if videos_folder is not None:
|
||||||
aux_columns += ['video_width', 'video_height', 'fps', 'length']
|
aux_columns += ['video_width', 'video_height', 'fps', 'length']
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user