491 lines
27 KiB
Plaintext
491 lines
27 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "8ef5cd92",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Re-import edited Python modules (e.g. the local datasets / models packages) before each cell runs\n",
"%load_ext autoreload\n",
"%autoreload 2"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "78c4643a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import sys\n",
|
|
"import os.path as op\n",
|
|
"import pandas as pd\n",
|
|
"import json\n",
|
|
"import base64"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "ffba4333",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Append the parent of the current working directory (presumably the repository root when the\n",
"# notebook runs from its own folder) so the local `datasets` and `models` packages are importable\n",
"sys.path.append(op.abspath('..'))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "5bc81f71",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# cuBLAS needs a fixed workspace configuration to be deterministic once\n",
"# torch.use_deterministic_algorithms(True) is enabled; it must be set before CUDA is initialised\n",
"os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "3de8bcf2",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import multiprocessing\n",
|
|
"from itertools import chain\n",
|
|
"import numpy as np\n",
|
|
"import random"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "91a045ba",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from bokeh.io import output_notebook, output_file\n",
|
|
"from bokeh.plotting import figure, show\n",
|
|
"from bokeh.models import LinearColorMapper, ColumnDataSource\n",
|
|
"from bokeh.transform import factor_cmap, factor_mark\n",
|
|
"from torch.utils.data import DataLoader\n",
|
|
"\n",
|
|
"\n",
|
|
"from datasets import SLREmbeddingDataset, collate_fn_padd\n",
|
|
"from datasets.dataset_loader import LocalDatasetLoader\n",
|
|
"from models import embeddings_scatter_plot_splits\n",
|
|
"from models import SPOTER_EMBEDDINGS"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "bc50c296",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<torch._C.Generator at 0x7fe4a6429f50>"
|
|
]
|
|
},
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Seed every RNG used here (Python, NumPy, PyTorch CPU and CUDA) for reproducible runs\n",
"seed = 43\n",
"random.seed(seed)\n",
"np.random.seed(seed)\n",
"# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so this does not change str hashing in\n",
"# the running kernel; it can only influence Python interpreters launched afterwards.\n",
"os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"torch.cuda.manual_seed_all(seed)\n",
"# Force deterministic kernels; use_deterministic_algorithms raises on ops that lack a deterministic\n",
"# implementation (cuBLAS additionally relies on CUBLAS_WORKSPACE_CONFIG, set in an earlier cell).\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.use_deterministic_algorithms(True) \n",
"# Dedicated seeded generator, usable e.g. as DataLoader(generator=...); the last line echoes it\n",
"generator = torch.Generator()\n",
"generator.manual_seed(seed)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "82766a17",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"BASE_DATA_FOLDER = '../data/'\n",
"# Also exported through the environment (presumably read by the project's dataset code)\n",
"os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n",
"\n",
"# Run on the GPU when PyTorch can see one, otherwise on the CPU\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "ead15a36",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"<All keys matched successfully>"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Alternative: load the checkpoint from ClearML instead of the local file system\n",
"# from clearml import InputModel\n",
"# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n",
"# checkpoint = torch.load(model.get_weights())\n",
"\n",
"CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_18.pth\"\n",
"# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources\n",
"checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n",
"# Config object stored in the checkpoint; it fixes the architecture the weights were trained with\n",
"config_args = checkpoint[\"config_args\"]\n",
"\n",
"model = SPOTER_EMBEDDINGS(\n",
"    features=config_args.vector_length,\n",
"    hidden_dim=config_args.hidden_dim,\n",
"    norm_emb=config_args.normalize_embeddings,\n",
").to(device)\n",
"\n",
"load_result = model.load_state_dict(checkpoint[\"state_dict\"])\n",
"# Inference only: switch off training-time behaviour (e.g. dropout) so the embeddings are deterministic\n",
"model.eval()\n",
"load_result"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "20f8036d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Choose which sign-language dataset to visualise; each branch sets the dataset folder name,\n",
"# a class count (num_classes, not used elsewhere in this notebook) and the split CSV file name.\n",
"# split_dataset_path is later formatted with the split name ('train' / 'val').\n",
"SL_DATASET = 'wlasl' # or 'lsa'\n",
"if SL_DATASET == 'wlasl':\n",
" dataset_name = \"processed\"\n",
" num_classes = 15\n",
" # NOTE(review): no '{}' placeholder, so every split reads the same spoter_test.csv -- confirm intended\n",
" split_dataset_path = \"spoter_test.csv\"\n",
"else:\n",
" dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n",
" num_classes = 64\n",
" split_dataset_path = \"LSA64_{}.csv\"\n",
" "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "758716b6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def get_dataset_loader(loader_name=None):\n",
"    \"\"\"Return the dataset loader backend.\n",
"\n",
"    'CLEARML' selects ClearMLDatasetLoader, imported only on that path (presumably so ClearML is\n",
"    needed only when chosen); any other value, including the default None, gives LocalDatasetLoader.\n",
"    \"\"\"\n",
"    if loader_name != 'CLEARML':\n",
"        return LocalDatasetLoader()\n",
"    from datasets.clearml_dataset_loader import ClearMLDatasetLoader\n",
"    return ClearMLDatasetLoader()\n",
"\n",
"\n",
"dataset_loader = get_dataset_loader()\n",
"dataset_project = \"Sign Language Recognition\"\n",
"batch_size = 1\n",
"dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "f1527959",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/usr/local/lib/python3.8/dist-packages/sklearn/manifold/_t_sne.py:780: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.\n",
|
|
" warnings.warn(\n",
|
|
"/usr/local/lib/python3.8/dist-packages/sklearn/manifold/_t_sne.py:790: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.\n",
|
|
" warnings.warn(\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# One non-shuffled DataLoader per split, so embedding order matches the CSV row order\n",
"dataloaders = {}\n",
"splits = ['train', 'val']\n",
"for split in splits:\n",
"    split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
"    split_set = SLREmbeddingDataset(split_set_path, triplet=False)\n",
"    dataloaders[split] = DataLoader(\n",
"        split_set,\n",
"        batch_size=batch_size,\n",
"        shuffle=False,\n",
"        collate_fn=collate_fn_padd,\n",
"        pin_memory=torch.cuda.is_available(),\n",
"        num_workers=multiprocessing.cpu_count(),\n",
"        # seeded generator from the setup cell, so worker seeding is reproducible too\n",
"        generator=generator,\n",
"    )\n",
"\n",
"# JSON object keys are always strings; convert them back to integer class ids\n",
"with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n",
"    id_to_label = {int(key): value for key, value in json.load(fid).items()}\n",
"\n",
"tsne_results, labels_results = embeddings_scatter_plot_splits(\n",
"    model,\n",
"    dataloaders,\n",
"    device,\n",
"    id_to_label,\n",
"    perplexity=40,\n",
"    n_iter=1000,\n",
")\n",
"\n",
"# Distinct labels across *all* splits (chain.from_iterable flattens the per-split label lists);\n",
"# previously only the first split's labels were counted. Used to size the colour mapper.\n",
"set_labels = list(set(chain.from_iterable(labels_results.values())))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "3c3af5bf",
|
|
"metadata": {
|
|
"lines_to_next_cell": 0
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"1220"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"N_TRAIN_SAMPLES = 20  # random train points plotted alongside every val point\n",
"\n",
"dfs = {}\n",
"for split in splits:\n",
"    split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n",
"    df = pd.read_csv(split_set_path)\n",
"    # Rows must line up 1:1 with the t-SNE points (the dataloaders use shuffle=False)\n",
"    assert len(df) == len(tsne_results[split]), f\"{split}: {len(df)} CSV rows vs {len(tsne_results[split])} t-SNE points\"\n",
"    df['tsne_x'] = tsne_results[split][:, 0]\n",
"    df['tsne_y'] = tsne_results[split][:, 1]\n",
"    df['split'] = split\n",
"    # Video paths are only needed for the video tooltip (use_img_div = True); uncomment the matching branch then.\n",
"    # if SL_DATASET == 'wlasl':\n",
"    #     df['video_fn'] = df['video_id'].apply(lambda video_id: os.path.join(BASE_DATA_FOLDER, f'wlasl/videos/{video_id:05d}.mp4'))\n",
"    # else:\n",
"    #     df['video_fn'] = df['video_id'].apply(lambda video_id: os.path.join(BASE_DATA_FOLDER, f'lsa/videos/{video_id}.mp4'))\n",
"    dfs[split] = df\n",
"\n",
"# Seeded sample so the plotted subset does not depend on how many times earlier cells were run\n",
"train_subset = dfs['train'].sample(min(N_TRAIN_SAMPLES, len(dfs['train'])), random_state=seed)\n",
"df = pd.concat([train_subset, dfs['val']]).reset_index(drop=True)\n",
"len(df)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 25,
|
|
"id": "dccbe1b9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from tqdm.auto import tqdm\n",
"\n",
"\n",
"def load_videos(video_list):\n",
"    \"\"\"Read every video file and return its bytes base64-encoded as a str.\n",
"\n",
"    None entries are passed through as None, so the result has one element per\n",
"    input element, in the same order.\n",
"    \"\"\"\n",
"    def _encode(path):\n",
"        if path is None:\n",
"            return None\n",
"        with open(path, 'rb') as handle:\n",
"            return base64.b64encode(handle.read()).decode()\n",
"\n",
"    print('loading videos')\n",
"    encoded = [_encode(path) for path in tqdm(video_list)]\n",
"    print('Done loading videos')\n",
"    return encoded"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 26,
|
|
"id": "904298f0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"use_img_div = False\n",
"if use_img_div:\n",
"    # Embedding a video per point would overload the scatter plot, so keep only the\n",
"    # points inside the window 10 < tsne_x < 20 and 10 < tsne_y < 20\n",
"    in_window = (\n",
"        (df['tsne_x'] > 10) & (df['tsne_x'] < 20)\n",
"        & (df['tsne_y'] > 10) & (df['tsne_y'] < 20)\n",
"    )\n",
"    df = df.loc[in_window]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 27,
|
|
"id": "42832f7c",
|
|
"metadata": {
|
|
"scrolled": false
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<div class=\"bk-root\">\n",
|
|
" <a href=\"https://bokeh.org\" target=\"_blank\" class=\"bk-logo bk-logo-small bk-logo-notebook\"></a>\n",
|
|
" <span id=\"1144\">Loading BokehJS ...</span>\n",
|
|
" </div>\n"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"data": {
|
|
"application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n\n if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n root._bokeh_onload_callbacks = [];\n root._bokeh_is_loading = undefined;\n }\n\nconst JS_MIME_TYPE = 'application/javascript';\n const HTML_MIME_TYPE = 'text/html';\n const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n const CLASS_NAME = 'output_bokeh rendered_html';\n\n /**\n * Render data to the DOM node\n */\n function render(props, node) {\n const script = document.createElement(\"script\");\n node.appendChild(script);\n }\n\n /**\n * Handle when an output is cleared or removed\n */\n function handleClearOutput(event, handle) {\n const cell = handle.cell;\n\n const id = cell.output_area._bokeh_element_id;\n const server_id = cell.output_area._bokeh_server_id;\n // Clean up Bokeh references\n if (id != null && id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n\n if (server_id !== undefined) {\n // Clean up Bokeh references\n const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n cell.notebook.kernel.execute(cmd_clean, {\n iopub: {\n output: function(msg) {\n const id = msg.content.text.trim();\n if (id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n }\n }\n });\n // Destroy server and session\n const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n cell.notebook.kernel.execute(cmd_destroy);\n }\n }\n\n /**\n * Handle when a new output is added\n */\n function handleAddOutput(event, handle) {\n const output_area = handle.output_area;\n const output = handle.output;\n\n // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n if ((output.output_type != \"display_data\") || 
(!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n return\n }\n\n const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n\n if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n // store reference to embed id on output_area\n output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n }\n if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n const bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n const script_attrs = bk_div.children[0].attributes;\n for (let i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n }\n\n function register_renderer(events, OutputArea) {\n\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n const toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[toinsert.length - 1]);\n element.append(toinsert);\n return toinsert\n }\n\n /* Handle when an output is cleared or removed */\n events.on('clear_output.CodeCell', handleClearOutput);\n events.on('delete.Cell', handleClearOutput);\n\n /* Handle when a new output is added */\n events.on('output_added.OutputArea', handleAddOutput);\n\n /**\n * Register the mime type and append_mime function with output_area\n */\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n /* Is output safe? 
*/\n safe: true,\n /* Index of renderer in `output_area.display_order` */\n index: 0\n });\n }\n\n // register the mime type if in Jupyter Notebook environment and previously unregistered\n if (root.Jupyter !== undefined) {\n const events = require('base/js/events');\n const OutputArea = require('notebook/js/outputarea').OutputArea;\n\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n }\n if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n const NB_LOAD_WARNING = {'data': {'text/html':\n \"<div style='background-color: #fdd'>\\n\"+\n \"<p>\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n \"</p>\\n\"+\n \"<ul>\\n\"+\n \"<li>re-rerun `output_notebook()` to attempt to load from CDN again, or</li>\\n\"+\n \"<li>use INLINE resources instead, as so:</li>\\n\"+\n \"</ul>\\n\"+\n \"<code>\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"</code>\\n\"+\n \"</div>\"}};\n\n function display_loaded() {\n const el = document.getElementById(\"1144\");\n if (el != null) {\n el.textContent = \"BokehJS is loading...\";\n }\n if (root.Bokeh !== undefined) {\n if (el != null) {\n el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(display_loaded, 100)\n }\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = 
[];\n\n root._bokeh_onload_callbacks.push(callback);\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls == null || js_urls.length === 0) {\n run_callbacks();\n return null;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n\n function on_error(url) {\n console.error(\"failed to load \" + url);\n }\n\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n }\n\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const element = document.createElement('script');\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n 
const css_urls = [];\n\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {\n }\n ];\n\n function run_inline_js() {\n if (root.Bokeh !== undefined || force === true) {\n for (let i = 0; i < inline_js.length; i++) {\n inline_js[i].call(root, root.Bokeh);\n }\nif (force === true) {\n display_loaded();\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n } else if (force !== true) {\n const cell = $(document.getElementById(\"1144\")).parents('.cell').data().cell;\n cell.output_area.append_execute_result(NB_LOAD_WARNING)\n }\n }\n\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n run_inline_js();\n } else {\n load_libs(css_urls, js_urls, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n run_inline_js();\n });\n }\n}(window));",
|
|
"application/vnd.bokehjs_load.v0+json": ""
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Optional video preview for the hover tooltip; @videos is the base64 \"videos\" column added when use_img_div is on\n",
"img_div = '''\n",
"    <div>\n",
"        <video autoplay src=\"data:video/mp4;base64,@videos\" height=\"90\" width=\"120\">\n",
"        </video>\n",
"    </div>\n",
"'''\n",
"# Hover tooltip (balanced markup: the previous version closed one <div> too many);\n",
"# @name fields refer to columns of the ColumnDataSource built in the plotting cell\n",
"TOOLTIPS = f\"\"\"\n",
"    <div>\n",
"        {img_div if use_img_div else ''}\n",
"        <div>\n",
"            <span style=\"font-size: 17px; font-weight: bold;\">@label_desc - @split</span>\n",
"            <span style=\"font-size: 15px; color: #966;\">[#@video_id]</span>\n",
"        </div>\n",
"    </div>\n",
"\"\"\"\n",
"# Map label values onto the Turbo256 palette\n",
"cmap = LinearColorMapper(palette=\"Turbo256\", low=0, high=len(set_labels))\n",
"\n",
"output_notebook()  # or output_file(\"scatter_plot.html\") for a standalone HTML file\n",
"\n",
"p = figure(width=1000,\n",
"           height=800,\n",
"           tooltips=TOOLTIPS,\n",
"           title=f\"Check {'video' if use_img_div else 'label'} by hovering mouse over the dots\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"id": "ead4daf7",
|
|
"metadata": {
|
|
"scrolled": false
|
|
},
|
|
"outputs": [
|
|
{
|
|
"ename": "KeyError",
|
|
"evalue": "'label_name'",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
|
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/core/indexes/base.py:3802\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 3801\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 3802\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_engine\u001b[39m.\u001b[39;49mget_loc(casted_key)\n\u001b[1;32m 3803\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n",
|
|
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/_libs/index.pyx:138\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
|
|
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/_libs/index.pyx:165\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
|
|
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5745\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
|
|
"File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5753\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n",
|
|
"\u001b[0;31mKeyError\u001b[0m: 'label_name'",
|
|
"\nThe above exception was the direct cause of the following exception:\n",
|
|
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[29], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m column_data \u001b[39m=\u001b[39m \u001b[39mdict\u001b[39m(\n\u001b[1;32m 2\u001b[0m x\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39mtsne_x\u001b[39m\u001b[39m'\u001b[39m],\n\u001b[1;32m 3\u001b[0m y\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39mtsne_y\u001b[39m\u001b[39m'\u001b[39m],\n\u001b[1;32m 4\u001b[0m label\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39msign\u001b[39m\u001b[39m'\u001b[39m],\n\u001b[0;32m----> 5\u001b[0m label_desc\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39;49m\u001b[39mlabel_name\u001b[39;49m\u001b[39m'\u001b[39;49m],\n\u001b[1;32m 6\u001b[0m split\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39msplit\u001b[39m\u001b[39m'\u001b[39m],\n\u001b[1;32m 7\u001b[0m video_id\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39mvideo_id\u001b[39m\u001b[39m'\u001b[39m]\n\u001b[1;32m 8\u001b[0m )\n\u001b[1;32m 10\u001b[0m \u001b[39mif\u001b[39;00m use_img_div:\n\u001b[1;32m 11\u001b[0m emb_videos \u001b[39m=\u001b[39m load_videos(df[\u001b[39m'\u001b[39m\u001b[39mvideo_fn\u001b[39m\u001b[39m'\u001b[39m])\n",
|
|
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/core/frame.py:3807\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3805\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcolumns\u001b[39m.\u001b[39mnlevels \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m 3806\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> 3807\u001b[0m indexer \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcolumns\u001b[39m.\u001b[39;49mget_loc(key)\n\u001b[1;32m 3808\u001b[0m \u001b[39mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m 3809\u001b[0m indexer \u001b[39m=\u001b[39m [indexer]\n",
|
|
"File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/core/indexes/base.py:3804\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 3802\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_engine\u001b[39m.\u001b[39mget_loc(casted_key)\n\u001b[1;32m 3803\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n\u001b[0;32m-> 3804\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mKeyError\u001b[39;00m(key) \u001b[39mfrom\u001b[39;00m \u001b[39merr\u001b[39;00m\n\u001b[1;32m 3805\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mTypeError\u001b[39;00m:\n\u001b[1;32m 3806\u001b[0m \u001b[39m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m 3807\u001b[0m \u001b[39m# InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m 3808\u001b[0m \u001b[39m# the TypeError.\u001b[39;00m\n\u001b[1;32m 3809\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_check_indexing_error(key)\n",
|
|
"\u001b[0;31mKeyError\u001b[0m: 'label_name'"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"column_data = dict(\n",
"    x=df['tsne_x'],\n",
"    y=df['tsne_y'],\n",
"    label=df['labels'],  # value fed to the colour mapper (assumed numeric class id)\n",
"    label_desc=df['sign'],  # text shown in the tooltip as @label_desc\n",
"    split=df['split'],\n",
"    video_id=df['video_id'],\n",
")\n",
"\n",
"if use_img_div:\n",
"    # video_fn is only created if the commented-out lines in the frame-building cell are enabled\n",
"    if 'video_fn' not in df.columns:\n",
"        raise KeyError(\"use_img_div=True requires a 'video_fn' column; enable the video_fn lines in the cell that builds dfs\")\n",
"    column_data[\"videos\"] = load_videos(df['video_fn'])\n",
"source = ColumnDataSource(data=column_data)\n",
"\n",
"# Colour fill and outline by label so each class forms one colour cluster.\n",
"# NOTE: re-running this cell adds another glyph layer to the same figure `p`.\n",
"p.scatter('x', 'y',\n",
"          size=10,\n",
"          source=source,\n",
"          fill_color={\"field\": 'label', \"transform\": cmap},\n",
"          line_color={\"field\": 'label', \"transform\": cmap})\n",
"\n",
"show(p)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1d761766",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Inspect the combined (train sample + val) frame that was plotted\n",
"df"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1c73f195",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"jupytext": {
|
|
"cell_metadata_filter": "-all",
|
|
"main_language": "python",
|
|
"notebook_metadata_filter": "-all"
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.10"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|