Updated training notebook

This commit is contained in:
Victor Mylle
2023-11-26 22:46:41 +00:00
parent 4886c0d9a0
commit 74c2c07b68

View File

@@ -7,7 +7,15 @@
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../..')\n",
"sys.path.append('../..')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from src.data import DataProcessor, DataConfig\n",
"from src.trainers.quantile_trainer import AutoRegressiveQuantileTrainer, NonAutoRegressiveQuantileRegression\n",
"from src.trainers.probabilistic_baseline import ProbabilisticBaselineTrainer\n",
@@ -34,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -159,7 +167,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -173,23 +181,35 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/workspaces/Thesis/src/notebooks/../../src/trainers/quantile_trainer.py:68: UserWarning:\n",
"\n",
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"\n",
"/workspaces/Thesis/src/notebooks/../../src/losses/pinball_loss.py:8: UserWarning:\n",
"\n",
"To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
"\n"
"/workspaces/Thesis/src/notebooks/../../src/trainers/quantile_trainer.py:68: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" quantiles_tensor = torch.tensor(quantiles)\n",
"/workspaces/Thesis/src/notebooks/../../src/losses/pinball_loss.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
" self.quantiles_tensor = torch.tensor(quantiles, dtype=torch.float32)\n",
"InsecureRequestWarning: Certificate verification is disabled! Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClearML Task: created new task id=cbf4a5162c604d6ea8f14e71e2d27410\n",
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/cbf4a5162c604d6ea8f14e71e2d27410/output/log\n",
"Early stopping triggered\n"
"ClearML Task: created new task id=4652507a84f5435fb6bd98c645d15f24\n",
"ClearML results page: http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/4652507a84f5435fb6bd98c645d15f24/output/log\n",
"2023-11-26 22:15:47,860 - clearml.Task - INFO - Storing jupyter notebook directly as code\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Switching to remote execution, output log page http://192.168.1.182:8080/projects/2e46d4af6f1e4c399cf9f5aa30bc8795/experiments/4652507a84f5435fb6bd98c645d15f24/output/log\n"
]
},
{
"ename": "",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
]
}
],
@@ -226,7 +246,7 @@
")\n",
"trainer.early_stopping(patience=10)\n",
"trainer.plot_every(5)\n",
"trainer.train(epochs=epochs)"
"trainer.train(epochs=epochs, remotely=True)"
]
},
{
@@ -315,103 +335,6 @@
"trainer.plot_every(5)\n",
"trainer.train(epochs=epochs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"dataset = data_processor.get_train_dataloader().dataset\n",
"dataset.predict_sequence_length = 1\n",
"dataset.data_config.LOAD_HISTORY = True\n",
"\n",
"\n",
"def auto_regressive_batch(dataset, idx_batch, sequence_length):\n",
" target_full = [] # (batch_size, sequence_length)\n",
" predictions_samples = [] # (batch_size, sequence_length)\n",
" predictions_full = [] # (batch_size, sequence_length, quantiles)\n",
"\n",
" prev_features, targets = dataset.get_batch(idx_batch)\n",
"\n",
" initial_sequence = prev_features[:, :96]\n",
"\n",
" target_full = targets[:, 0]\n",
" self.\n",
"\n",
"\n",
"\n",
"auto_regressive_batch(dataset, [0, 1, 2], 50)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def auto_regressive(self, data_loader, idx, sequence_length: int = 96):\n",
" self.model.eval()\n",
" target_full = []\n",
" predictions_sampled = []\n",
" predictions_full = []\n",
"\n",
" prev_features, target = data_loader.dataset[idx]\n",
" prev_features = prev_features.to(self.device)\n",
"\n",
" initial_sequence = prev_features[:96]\n",
"\n",
" target_full.append(target)\n",
" with torch.no_grad():\n",
" prediction = self.model(prev_features.unsqueeze(0))\n",
" predictions_full.append(prediction.squeeze(0))\n",
"\n",
" # sample from the distribution\n",
" sample = self.sample_from_dist(\n",
" self.quantiles.cpu(), prediction.squeeze(-1).cpu().numpy()\n",
" )\n",
" predictions_sampled.append(sample)\n",
"\n",
" for i in range(sequence_length - 1):\n",
" new_features = torch.cat(\n",
" (prev_features[1:96].cpu(), torch.tensor([predictions_sampled[-1]])),\n",
" dim=0,\n",
" )\n",
" new_features = new_features.float()\n",
"\n",
" # get the other needed features\n",
" other_features, new_target = data_loader.dataset.random_day_autoregressive(\n",
" idx + i + 1\n",
" )\n",
"\n",
" if other_features is not None:\n",
" prev_features = torch.cat((new_features, other_features), dim=0)\n",
" else:\n",
" prev_features = new_features\n",
"\n",
" # add target to target_full\n",
" target_full.append(new_target)\n",
"\n",
" # predict\n",
" with torch.no_grad():\n",
" prediction = self.model(prev_features.unsqueeze(0).to(self.device))\n",
" predictions_full.append(prediction.squeeze(0))\n",
"\n",
" # sample from the distribution\n",
" sample = self.sample_from_dist(\n",
" self.quantiles.cpu(), prediction.squeeze(-1).cpu().numpy()\n",
" )\n",
" predictions_sampled.append(sample)\n",
"\n",
" return (\n",
" initial_sequence.cpu(),\n",
" torch.stack(predictions_full).cpu(),\n",
" torch.tensor(predictions_sampled).reshape(-1, 1),\n",
" torch.stack(target_full).cpu(),\n",
" )"
]
}
],
"metadata": {