From 4ad3336b98d4817d0b26d75ec79871811bdcc42f Mon Sep 17 00:00:00 2001 From: Victor Mylle Date: Wed, 21 Feb 2024 18:13:51 +0100 Subject: [PATCH] Set training script to execute remotely --- src/training_scripts/autoregressive_quantiles.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/training_scripts/autoregressive_quantiles.py b/src/training_scripts/autoregressive_quantiles.py index c323c12..1fac750 100644 --- a/src/training_scripts/autoregressive_quantiles.py +++ b/src/training_scripts/autoregressive_quantiles.py @@ -15,7 +15,7 @@ from src.models.time_embedding_layer import TimeEmbedding #### ClearML #### clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast") -task = clearml_helper.get_task(task_name="Autoregressive Quantile Regression: Linear + Quarter + DoW + Load + Wind + Net") +task = clearml_helper.get_task(task_name="Autoregressive Quantile Regression: Non Linear") #### Data Processor #### @@ -59,11 +59,11 @@ else: quantiles = eval(quantiles) model_parameters = { - "learning_rate": 0.0001, + "learning_rate": 0.001, "hidden_size": 512, - "num_layers": 2, + "num_layers": 4, "dropout": 0.2, - "time_feature_embedding": 4, + "time_feature_embedding": 8, } model_parameters = task.connect(model_parameters, name="model_parameters") @@ -96,9 +96,9 @@ trainer = AutoRegressiveQuantileTrainer( trainer.add_metrics_to_track( [PinballLoss(quantiles), MSELoss(), L1Loss(), CRPSLoss(quantiles)] ) -trainer.early_stopping(patience=10) +trainer.early_stopping(patience=15) trainer.plot_every(5) -trainer.train(task=task, epochs=epochs, remotely=False) +trainer.train(task=task, epochs=epochs, remotely=True) ### Policy Evaluation ### idx_samples = trainer.test_set_samples