{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "8ef5cd92", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "78c4643a", "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "import os.path as op\n", "import pandas as pd\n", "import json\n", "import base64" ] }, { "cell_type": "code", "execution_count": 3, "id": "ffba4333", "metadata": {}, "outputs": [], "source": [ "sys.path.append(op.abspath('..'))" ] }, { "cell_type": "code", "execution_count": 4, "id": "5bc81f71", "metadata": {}, "outputs": [], "source": [ "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":16:8\"" ] }, { "cell_type": "code", "execution_count": 5, "id": "3de8bcf2", "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "import torch\n", "import multiprocessing\n", "from itertools import chain\n", "import numpy as np\n", "import random" ] }, { "cell_type": "code", "execution_count": 6, "id": "91a045ba", "metadata": {}, "outputs": [], "source": [ "from bokeh.io import output_notebook, output_file\n", "from bokeh.plotting import figure, show\n", "from bokeh.models import LinearColorMapper, ColumnDataSource\n", "from bokeh.transform import factor_cmap, factor_mark\n", "from torch.utils.data import DataLoader\n", "\n", "\n", "from datasets import SLREmbeddingDataset, collate_fn_padd\n", "from datasets.dataset_loader import LocalDatasetLoader\n", "from models import embeddings_scatter_plot_splits\n", "from models import SPOTER_EMBEDDINGS" ] }, { "cell_type": "code", "execution_count": 7, "id": "bc50c296", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "seed = 43\n", "random.seed(seed)\n", "np.random.seed(seed)\n", "os.environ[\"PYTHONHASHSEED\"] = str(seed)\n", "torch.manual_seed(seed)\n", "torch.cuda.manual_seed(seed)\n", "torch.cuda.manual_seed_all(seed)\n", 
"torch.backends.cudnn.deterministic = True\n", "torch.use_deterministic_algorithms(True) \n", "generator = torch.Generator()\n", "generator.manual_seed(seed)" ] }, { "cell_type": "code", "execution_count": 8, "id": "82766a17", "metadata": {}, "outputs": [], "source": [ "BASE_DATA_FOLDER = '../data/'\n", "os.environ[\"BASE_DATA_FOLDER\"] = BASE_DATA_FOLDER\n", "device = torch.device(\"cpu\")\n", "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "ead15a36", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# LOAD MODEL FROM CLEARML\n", "# from clearml import InputModel\n", "# model = InputModel(model_id='1b736da469b04e91b8451d2342aef6ce')\n", "# checkpoint = torch.load(model.get_weights())\n", "\n", "\n", "CHECKPOINT_PATH = \"../out-checkpoints/augment_rotate_75_x8/checkpoint_embed_18.pth\"\n", "checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)\n", "\n", "\n", "model = SPOTER_EMBEDDINGS(\n", " features=checkpoint[\"config_args\"].vector_length,\n", " hidden_dim=checkpoint[\"config_args\"].hidden_dim,\n", " norm_emb=checkpoint[\"config_args\"].normalize_embeddings,\n", ").to(device)\n", "\n", "model.load_state_dict(checkpoint[\"state_dict\"])" ] }, { "cell_type": "code", "execution_count": 18, "id": "20f8036d", "metadata": {}, "outputs": [], "source": [ "SL_DATASET = 'wlasl' # or 'lsa'\n", "if SL_DATASET == 'wlasl':\n", " dataset_name = \"processed\"\n", " num_classes = 15\n", " split_dataset_path = \"spoter_test.csv\"\n", "else:\n", " dataset_name = \"lsa64_mapped_mediapipe_only_landmarks_25fps\"\n", " num_classes = 64\n", " split_dataset_path = \"LSA64_{}.csv\"\n", " " ] }, { "cell_type": "code", "execution_count": 19, "id": "758716b6", "metadata": {}, "outputs": [], "source": [ "def get_dataset_loader(loader_name=None):\n", " if loader_name == 'CLEARML':\n", " from 
datasets.clearml_dataset_loader import ClearMLDatasetLoader\n", " return ClearMLDatasetLoader()\n", " else:\n", " return LocalDatasetLoader()\n", "\n", "dataset_loader = get_dataset_loader()\n", "dataset_project = \"Sign Language Recognition\"\n", "batch_size = 1\n", "dataset_folder = dataset_loader.get_dataset_folder(dataset_project, dataset_name)" ] }, { "cell_type": "code", "execution_count": 20, "id": "f1527959", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.8/dist-packages/sklearn/manifold/_t_sne.py:780: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.\n", " warnings.warn(\n", "/usr/local/lib/python3.8/dist-packages/sklearn/manifold/_t_sne.py:790: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.\n", " warnings.warn(\n" ] } ], "source": [ "dataloaders = {}\n", "splits = ['train', 'val']\n", "dfs = {}\n", "for split in splits:\n", " split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n", " split_set = SLREmbeddingDataset(split_set_path, triplet=False)\n", " data_loader = DataLoader(\n", " split_set,\n", " batch_size=batch_size,\n", " shuffle=False,\n", " collate_fn=collate_fn_padd,\n", " pin_memory=torch.cuda.is_available(),\n", " num_workers=multiprocessing.cpu_count()\n", " )\n", " dataloaders[split] = data_loader\n", " dfs[split] = pd.read_csv(split_set_path)\n", "\n", "with open(op.join(dataset_folder, 'id_to_label.json')) as fid:\n", " id_to_label = json.load(fid)\n", "id_to_label = {int(key): value for key, value in id_to_label.items()}\n", "\n", "tsne_results, labels_results = embeddings_scatter_plot_splits(model,\n", " dataloaders,\n", " device,\n", " id_to_label,\n", " perplexity=40,\n", " n_iter=1000)\n", "\n", "\n", "# Flatten label lists from ALL splits before deduplicating; next(chain(...)) would only take the first split's list\n", "set_labels = list(set(chain(*labels_results.values())))" ] }, { "cell_type": "code", "execution_count": 24, "id": "3c3af5bf", "metadata": { "lines_to_next_cell": 0 
}, "outputs": [ { "data": { "text/plain": [ "1220" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dfs = {}\n", "for split in splits:\n", " split_set_path = op.join(dataset_folder, split_dataset_path.format(split))\n", " df = pd.read_csv(split_set_path)\n", " df['tsne_x'] = tsne_results[split][:, 0]\n", " df['tsne_y'] = tsne_results[split][:, 1]\n", " df['split'] = split\n", " # if SL_DATASET == 'wlasl':\n", " # df['video_fn'] = df['video_id'].apply(lambda video_id: os.path.join(BASE_DATA_FOLDER, f'wlasl/videos/{video_id:05d}.mp4'))\n", " # else:\n", " # df['video_fn'] = df['video_id'].apply(lambda video_id: os.path.join(BASE_DATA_FOLDER, f'lsa/videos/{video_id}.mp4'))\n", " dfs[split] = df\n", "\n", "df = pd.concat([dfs['train'].sample(20), dfs['val']]).reset_index(drop=True)\n", "len(df)" ] }, { "cell_type": "code", "execution_count": 25, "id": "dccbe1b9", "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "\n", "def load_videos(video_list):\n", " print('loading videos')\n", " videos = []\n", " for video_fn in tqdm(video_list):\n", " if video_fn is None:\n", " video_data = None\n", " else:\n", " with open(video_fn, 'rb') as fid:\n", " video_data = base64.b64encode(fid.read()).decode()\n", " videos.append(video_data)\n", " print('Done loading videos')\n", " return videos" ] }, { "cell_type": "code", "execution_count": 26, "id": "904298f0", "metadata": {}, "outputs": [], "source": [ "use_img_div = False\n", "if use_img_div:\n", " # sample dataframe data to avoid overloading scatter plot with too many videos\n", " df = df.loc[(df['tsne_x'] > 10) & (df['tsne_x'] < 20)]\n", " df = df.loc[(df['tsne_y'] > 10) & (df['tsne_y'] < 20)]" ] }, { "cell_type": "code", "execution_count": 27, "id": "42832f7c", "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
\n", " \n", " Loading BokehJS ...\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/javascript": "(function(root) {\n function now() {\n return new Date();\n }\n\n const force = true;\n\n if (typeof root._bokeh_onload_callbacks === \"undefined\" || force === true) {\n root._bokeh_onload_callbacks = [];\n root._bokeh_is_loading = undefined;\n }\n\nconst JS_MIME_TYPE = 'application/javascript';\n const HTML_MIME_TYPE = 'text/html';\n const EXEC_MIME_TYPE = 'application/vnd.bokehjs_exec.v0+json';\n const CLASS_NAME = 'output_bokeh rendered_html';\n\n /**\n * Render data to the DOM node\n */\n function render(props, node) {\n const script = document.createElement(\"script\");\n node.appendChild(script);\n }\n\n /**\n * Handle when an output is cleared or removed\n */\n function handleClearOutput(event, handle) {\n const cell = handle.cell;\n\n const id = cell.output_area._bokeh_element_id;\n const server_id = cell.output_area._bokeh_server_id;\n // Clean up Bokeh references\n if (id != null && id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n\n if (server_id !== undefined) {\n // Clean up Bokeh references\n const cmd_clean = \"from bokeh.io.state import curstate; print(curstate().uuid_to_server['\" + server_id + \"'].get_sessions()[0].document.roots[0]._id)\";\n cell.notebook.kernel.execute(cmd_clean, {\n iopub: {\n output: function(msg) {\n const id = msg.content.text.trim();\n if (id in Bokeh.index) {\n Bokeh.index[id].model.document.clear();\n delete Bokeh.index[id];\n }\n }\n }\n });\n // Destroy server and session\n const cmd_destroy = \"import bokeh.io.notebook as ion; ion.destroy_server('\" + server_id + \"')\";\n cell.notebook.kernel.execute(cmd_destroy);\n }\n }\n\n /**\n * Handle when a new output is added\n */\n function handleAddOutput(event, handle) {\n const output_area = handle.output_area;\n const output = handle.output;\n\n // limit handleAddOutput to display_data with EXEC_MIME_TYPE content only\n if 
((output.output_type != \"display_data\") || (!Object.prototype.hasOwnProperty.call(output.data, EXEC_MIME_TYPE))) {\n return\n }\n\n const toinsert = output_area.element.find(\".\" + CLASS_NAME.split(' ')[0]);\n\n if (output.metadata[EXEC_MIME_TYPE][\"id\"] !== undefined) {\n toinsert[toinsert.length - 1].firstChild.textContent = output.data[JS_MIME_TYPE];\n // store reference to embed id on output_area\n output_area._bokeh_element_id = output.metadata[EXEC_MIME_TYPE][\"id\"];\n }\n if (output.metadata[EXEC_MIME_TYPE][\"server_id\"] !== undefined) {\n const bk_div = document.createElement(\"div\");\n bk_div.innerHTML = output.data[HTML_MIME_TYPE];\n const script_attrs = bk_div.children[0].attributes;\n for (let i = 0; i < script_attrs.length; i++) {\n toinsert[toinsert.length - 1].firstChild.setAttribute(script_attrs[i].name, script_attrs[i].value);\n toinsert[toinsert.length - 1].firstChild.textContent = bk_div.children[0].textContent\n }\n // store reference to server id on output_area\n output_area._bokeh_server_id = output.metadata[EXEC_MIME_TYPE][\"server_id\"];\n }\n }\n\n function register_renderer(events, OutputArea) {\n\n function append_mime(data, metadata, element) {\n // create a DOM node to render to\n const toinsert = this.create_output_subarea(\n metadata,\n CLASS_NAME,\n EXEC_MIME_TYPE\n );\n this.keyboard_manager.register_events(toinsert);\n // Render to node\n const props = {data: data, metadata: metadata[EXEC_MIME_TYPE]};\n render(props, toinsert[toinsert.length - 1]);\n element.append(toinsert);\n return toinsert\n }\n\n /* Handle when an output is cleared or removed */\n events.on('clear_output.CodeCell', handleClearOutput);\n events.on('delete.Cell', handleClearOutput);\n\n /* Handle when a new output is added */\n events.on('output_added.OutputArea', handleAddOutput);\n\n /**\n * Register the mime type and append_mime function with output_area\n */\n OutputArea.prototype.register_mime_type(EXEC_MIME_TYPE, append_mime, {\n /* Is output safe? 
*/\n safe: true,\n /* Index of renderer in `output_area.display_order` */\n index: 0\n });\n }\n\n // register the mime type if in Jupyter Notebook environment and previously unregistered\n if (root.Jupyter !== undefined) {\n const events = require('base/js/events');\n const OutputArea = require('notebook/js/outputarea').OutputArea;\n\n if (OutputArea.prototype.mime_types().indexOf(EXEC_MIME_TYPE) == -1) {\n register_renderer(events, OutputArea);\n }\n }\n if (typeof (root._bokeh_timeout) === \"undefined\" || force === true) {\n root._bokeh_timeout = Date.now() + 5000;\n root._bokeh_failed_load = false;\n }\n\n const NB_LOAD_WARNING = {'data': {'text/html':\n \"
\\n\"+\n \"

\\n\"+\n \"BokehJS does not appear to have successfully loaded. If loading BokehJS from CDN, this \\n\"+\n \"may be due to a slow or bad network connection. Possible fixes:\\n\"+\n \"

\\n\"+\n \"
    \\n\"+\n \"
  • re-rerun `output_notebook()` to attempt to load from CDN again, or
  • \\n\"+\n \"
  • use INLINE resources instead, as so:
  • \\n\"+\n \"
\\n\"+\n \"\\n\"+\n \"from bokeh.resources import INLINE\\n\"+\n \"output_notebook(resources=INLINE)\\n\"+\n \"\\n\"+\n \"
\"}};\n\n function display_loaded() {\n const el = document.getElementById(\"1144\");\n if (el != null) {\n el.textContent = \"BokehJS is loading...\";\n }\n if (root.Bokeh !== undefined) {\n if (el != null) {\n el.textContent = \"BokehJS \" + root.Bokeh.version + \" successfully loaded.\";\n }\n } else if (Date.now() < root._bokeh_timeout) {\n setTimeout(display_loaded, 100)\n }\n }\n\n function run_callbacks() {\n try {\n root._bokeh_onload_callbacks.forEach(function(callback) {\n if (callback != null)\n callback();\n });\n } finally {\n delete root._bokeh_onload_callbacks\n }\n console.debug(\"Bokeh: all callbacks have finished\");\n }\n\n function load_libs(css_urls, js_urls, callback) {\n if (css_urls == null) css_urls = [];\n if (js_urls == null) js_urls = [];\n\n root._bokeh_onload_callbacks.push(callback);\n if (root._bokeh_is_loading > 0) {\n console.debug(\"Bokeh: BokehJS is being loaded, scheduling callback at\", now());\n return null;\n }\n if (js_urls == null || js_urls.length === 0) {\n run_callbacks();\n return null;\n }\n console.debug(\"Bokeh: BokehJS not loaded, scheduling load and callback at\", now());\n root._bokeh_is_loading = css_urls.length + js_urls.length;\n\n function on_load() {\n root._bokeh_is_loading--;\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: all BokehJS libraries/stylesheets loaded\");\n run_callbacks()\n }\n }\n\n function on_error(url) {\n console.error(\"failed to load \" + url);\n }\n\n for (let i = 0; i < css_urls.length; i++) {\n const url = css_urls[i];\n const element = document.createElement(\"link\");\n element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.rel = \"stylesheet\";\n element.type = \"text/css\";\n element.href = url;\n console.debug(\"Bokeh: injecting link tag for BokehJS stylesheet: \", url);\n document.body.appendChild(element);\n }\n\n for (let i = 0; i < js_urls.length; i++) {\n const url = js_urls[i];\n const element = document.createElement('script');\n 
element.onload = on_load;\n element.onerror = on_error.bind(null, url);\n element.async = false;\n element.src = url;\n console.debug(\"Bokeh: injecting script tag for BokehJS library: \", url);\n document.head.appendChild(element);\n }\n };\n\n function inject_raw_css(css) {\n const element = document.createElement(\"style\");\n element.appendChild(document.createTextNode(css));\n document.body.appendChild(element);\n }\n\n const js_urls = [\"https://cdn.bokeh.org/bokeh/release/bokeh-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-gl-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-widgets-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-tables-2.4.3.min.js\", \"https://cdn.bokeh.org/bokeh/release/bokeh-mathjax-2.4.3.min.js\"];\n const css_urls = [];\n\n const inline_js = [ function(Bokeh) {\n Bokeh.set_log_level(\"info\");\n },\nfunction(Bokeh) {\n }\n ];\n\n function run_inline_js() {\n if (root.Bokeh !== undefined || force === true) {\n for (let i = 0; i < inline_js.length; i++) {\n inline_js[i].call(root, root.Bokeh);\n }\nif (force === true) {\n display_loaded();\n }} else if (Date.now() < root._bokeh_timeout) {\n setTimeout(run_inline_js, 100);\n } else if (!root._bokeh_failed_load) {\n console.log(\"Bokeh: BokehJS failed to load within specified timeout.\");\n root._bokeh_failed_load = true;\n } else if (force !== true) {\n const cell = $(document.getElementById(\"1144\")).parents('.cell').data().cell;\n cell.output_area.append_execute_result(NB_LOAD_WARNING)\n }\n }\n\n if (root._bokeh_is_loading === 0) {\n console.debug(\"Bokeh: BokehJS loaded, going straight to plotting\");\n run_inline_js();\n } else {\n load_libs(css_urls, js_urls, function() {\n console.debug(\"Bokeh: BokehJS plotting callback run at\", now());\n run_inline_js();\n });\n }\n}(window));", "application/vnd.bokehjs_load.v0+json": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "img_div = '''\n", "
\n", " \n", "
\n", "'''\n", "TOOLTIPS = f\"\"\"\n", "
\n", " {img_div if use_img_div else ''}\n", "
\n", " @label_desc - @split\n", " [#@video_id]\n", "
\n", "
\n", " \n", "\"\"\"\n", "cmap = LinearColorMapper(palette=\"Turbo256\", low=0, high=len(set_labels))\n", "\n", "output_notebook()\n", "# or \n", "# output_file(\"scatter_plot.html\")\n", "\n", "p = figure(width=1000,\n", " height=800,\n", " tooltips=TOOLTIPS,\n", " title=f\"Check {'video' if use_img_div else 'label'} by hovering mouse over the dots\")" ] }, { "cell_type": "code", "execution_count": 29, "id": "ead4daf7", "metadata": { "scrolled": false }, "outputs": [ { "ename": "KeyError", "evalue": "'label_name'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/core/indexes/base.py:3802\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 3801\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 3802\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_engine\u001b[39m.\u001b[39;49mget_loc(casted_key)\n\u001b[1;32m 3803\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n", "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/_libs/index.pyx:138\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n", "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/_libs/index.pyx:165\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n", "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5745\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n", "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:5753\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n", "\u001b[0;31mKeyError\u001b[0m: 'label_name'", "\nThe above exception 
was the direct cause of the following exception:\n", "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[29], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m column_data \u001b[39m=\u001b[39m \u001b[39mdict\u001b[39m(\n\u001b[1;32m 2\u001b[0m x\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39mtsne_x\u001b[39m\u001b[39m'\u001b[39m],\n\u001b[1;32m 3\u001b[0m y\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39mtsne_y\u001b[39m\u001b[39m'\u001b[39m],\n\u001b[1;32m 4\u001b[0m label\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39msign\u001b[39m\u001b[39m'\u001b[39m],\n\u001b[0;32m----> 5\u001b[0m label_desc\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39;49m\u001b[39mlabel_name\u001b[39;49m\u001b[39m'\u001b[39;49m],\n\u001b[1;32m 6\u001b[0m split\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39msplit\u001b[39m\u001b[39m'\u001b[39m],\n\u001b[1;32m 7\u001b[0m video_id\u001b[39m=\u001b[39mdf[\u001b[39m'\u001b[39m\u001b[39mvideo_id\u001b[39m\u001b[39m'\u001b[39m]\n\u001b[1;32m 8\u001b[0m )\n\u001b[1;32m 10\u001b[0m \u001b[39mif\u001b[39;00m use_img_div:\n\u001b[1;32m 11\u001b[0m emb_videos \u001b[39m=\u001b[39m load_videos(df[\u001b[39m'\u001b[39m\u001b[39mvideo_fn\u001b[39m\u001b[39m'\u001b[39m])\n", "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/core/frame.py:3807\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3805\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcolumns\u001b[39m.\u001b[39mnlevels \u001b[39m>\u001b[39m \u001b[39m1\u001b[39m:\n\u001b[1;32m 3806\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> 3807\u001b[0m indexer \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcolumns\u001b[39m.\u001b[39;49mget_loc(key)\n\u001b[1;32m 3808\u001b[0m \u001b[39mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m 3809\u001b[0m indexer 
\u001b[39m=\u001b[39m [indexer]\n", "File \u001b[0;32m/usr/local/lib/python3.8/dist-packages/pandas/core/indexes/base.py:3804\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key, method, tolerance)\u001b[0m\n\u001b[1;32m 3802\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_engine\u001b[39m.\u001b[39mget_loc(casted_key)\n\u001b[1;32m 3803\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mKeyError\u001b[39;00m \u001b[39mas\u001b[39;00m err:\n\u001b[0;32m-> 3804\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mKeyError\u001b[39;00m(key) \u001b[39mfrom\u001b[39;00m \u001b[39merr\u001b[39;00m\n\u001b[1;32m 3805\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mTypeError\u001b[39;00m:\n\u001b[1;32m 3806\u001b[0m \u001b[39m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m 3807\u001b[0m \u001b[39m# InvalidIndexError. Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m 3808\u001b[0m \u001b[39m# the TypeError.\u001b[39;00m\n\u001b[1;32m 3809\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_check_indexing_error(key)\n", "\u001b[0;31mKeyError\u001b[0m: 'label_name'" ] } ], "source": [ "column_data = dict(\n", " x=df['tsne_x'],\n", " y=df['tsne_y'],\n", " label=df['labels'],\n", " label_desc=df['sign'],\n", " split=df['split'],\n", " video_id=df['video_id']\n", ")\n", "\n", "if use_img_div:\n", " emb_videos = load_videos(df['video_fn'])\n", " column_data[\"videos\"] = emb_videos\n", "source = ColumnDataSource(data=column_data)\n", "\n", "p.scatter('x', 'y',\n", " size=10,\n", " source=source,\n", " fill_color={\"field\": 'label', \"transform\": cmap},\n", " line_color={\"field\": 'label', \"transform\": cmap}, \n", " #legend_label={\"field\": 'split', \"transform\": lambda x: df['split']},\n", "# marker={\"field\": 'split'}\n", " )\n", "\n", "show(p)" ] }, { "cell_type": "code", "execution_count": null, "id": "1d761766", "metadata": {}, "outputs": [], "source": [ "df" ] }, { 
"cell_type": "code", "execution_count": null, "id": "1c73f195", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }