Added LSTM model
This commit is contained in:
@@ -5,7 +5,6 @@ from clearml.automation.optuna import OptimizerOptuna
|
||||
from clearml.automation import (
|
||||
DiscreteParameterRange, HyperParameterOptimizer, RandomSearch,
|
||||
UniformIntegerParameterRange)
|
||||
from src.data.preprocessing import DataConfig
|
||||
|
||||
# trying to load Bayesian optimizer package
|
||||
try:
|
||||
@@ -21,17 +20,28 @@ except ImportError as ex:
|
||||
'we will be using RandomSearch strategy instead')
|
||||
aSearchStrategy = RandomSearch
|
||||
|
||||
# input task id to optimize
|
||||
input_task_id = input("Please enter the task id to optimize: ")
|
||||
# input task id to optimize using argparse
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task_id", help="Task ID to optimize", type=str)
|
||||
args = parser.parse_args()
|
||||
input_task_id = args.task_id
|
||||
|
||||
# check if task id is valid
|
||||
if not Task.get_task(task_id=input_task_id):
|
||||
raise ValueError("Invalid task id")
|
||||
|
||||
task = Task.init(project_name='Hyper-Parameter Optimization',
|
||||
task_name='Automatic Hyper-Parameter Optimization',
|
||||
Task.add_requirements("requirements.txt")
|
||||
Task.ignore_requirements("torch")
|
||||
Task.ignore_requirements("torchvision")
|
||||
Task.ignore_requirements("tensorboard")
|
||||
task = Task.init(project_name='Thesis/NrvForecast',
|
||||
task_name='Autoregressive Quantile Regression Hyper-Parameter Optimization',
|
||||
task_type=Task.TaskTypes.optimizer,
|
||||
reuse_last_task_id=False)
|
||||
task.set_base_docker(f"docker.io/clearml/pytorch-cuda-gcc:2.0.0-cuda11.7-cudnn8-runtime --env GIT_SSL_NO_VERIFY=true --env CLEARML_AGENT_GIT_USER=VictorMylle --env CLEARML_AGENT_GIT_PASS=Voetballer1" )
|
||||
task.set_packages("requirements.txt")
|
||||
|
||||
|
||||
execution_queue = "default"
|
||||
|
||||
@@ -40,36 +50,42 @@ execution_queue = "default"
|
||||
#### Quantiles ####
|
||||
quantile_lists = [
|
||||
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], # Deciles
|
||||
[0.25, 0.5, 0.75], # Quartiles
|
||||
[0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95], # 10% Increments, Excluding Extremes
|
||||
[0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99], # Combining Deciles with Extremes
|
||||
[0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1], # Including 0 and 1
|
||||
[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], # Mixed Small and Large Increments
|
||||
[0.2, 0.4, 0.6, 0.8], # 20% Increments
|
||||
[0.125, 0.375, 0.625, 0.875], # Eighths
|
||||
[0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90], # 10% Increments
|
||||
[0.01, 0.02, 0.03, 0.04, 0.05, 0.1, 0.15, 0.2, 0.3, 0.5] # Mixed Fine and Coarser Increments
|
||||
[0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.70, 0.80, 0.90, 0.95], # 10% Increments
|
||||
]
|
||||
|
||||
|
||||
quantiles_range = DiscreteParameterRange("general/quantiles", values=quantile_lists)
|
||||
|
||||
#### Data Config ####
|
||||
quarter_range = DiscreteParameterRange("data_features/quarter", values=[True, False])
|
||||
day_of_week_range = DiscreteParameterRange("data_features/day_of_week", values=[True, False])
|
||||
|
||||
load_forecast_range = DiscreteParameterRange("data_features/load_forecast", values=[True, False])
|
||||
load_history_range = DiscreteParameterRange("data_features/load_history", values=[True, False])
|
||||
|
||||
### OPTIMIZER OBJECT ###
|
||||
optimizer = HyperParameterOptimizer(
|
||||
base_task_id=input_task_id,
|
||||
objective_metric_title="PinballLoss",
|
||||
objective_metric_series="test",
|
||||
objective_metric_title="Summary",
|
||||
objective_metric_series="test_CRPSLoss",
|
||||
objective_metric_sign="min",
|
||||
execution_queue=execution_queue,
|
||||
max_number_of_concurrent_tasks=1,
|
||||
optimizer_class=aSearchStrategy,
|
||||
max_iteration_per_job=50,
|
||||
# save_top_k_tasks_only=3,
|
||||
pool_period_min=0.2,
|
||||
total_max_jobs=15,
|
||||
|
||||
hyper_parameters=[
|
||||
quantiles_range,
|
||||
quarter_range,
|
||||
day_of_week_range,
|
||||
load_forecast_range,
|
||||
load_history_range
|
||||
]
|
||||
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user