Improved policy executer
This commit is contained in:
@@ -13,7 +13,7 @@ from src.models.time_embedding_layer import TimeEmbedding
|
||||
|
||||
#### ClearML ####
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
task = clearml_helper.get_task(task_name="Autoregressive Quantile Regression: GRU + Quarter + DoW + Load + Wind + Net")
|
||||
task = clearml_helper.get_task(task_name="Autoregressive Quantile Regression: Non Linear + Quarter + DoW + Load + Wind + Net")
|
||||
|
||||
|
||||
#### Data Processor ####
|
||||
@@ -35,7 +35,7 @@ data_config.NOMINAL_NET_POSITION = True
|
||||
|
||||
data_config = task.connect(data_config, name="data_features")
|
||||
|
||||
data_processor = DataProcessor(data_config, path="", lstm=True)
|
||||
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||
data_processor.set_batch_size(512)
|
||||
data_processor.set_full_day_skip(False)
|
||||
|
||||
@@ -67,9 +67,10 @@ model_parameters = {
|
||||
model_parameters = task.connect(model_parameters, name="model_parameters")
|
||||
|
||||
time_embedding = TimeEmbedding(data_processor.get_time_feature_size(), model_parameters["time_feature_embedding"])
|
||||
lstm_model = GRUModel(time_embedding.output_dim(inputDim), len(quantiles), hidden_size=model_parameters["hidden_size"], num_layers=model_parameters["num_layers"], dropout=model_parameters["dropout"])
|
||||
# lstm_model = GRUModel(time_embedding.output_dim(inputDim), len(quantiles), hidden_size=model_parameters["hidden_size"], num_layers=model_parameters["num_layers"], dropout=model_parameters["dropout"])
|
||||
non_linear_model = NonLinearRegression(time_embedding.output_dim(inputDim), len(quantiles), hiddenSize=model_parameters["hidden_size"], numLayers=model_parameters["num_layers"], dropout=model_parameters["dropout"])
|
||||
|
||||
model = nn.Sequential(time_embedding, lstm_model)
|
||||
model = nn.Sequential(time_embedding, non_linear_model)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=model_parameters["learning_rate"])
|
||||
|
||||
#### Trainer ####
|
||||
|
||||
Reference in New Issue
Block a user