Updated some stuff

This commit is contained in:
2024-03-20 22:14:18 +01:00
parent acaa8ff054
commit dad64d00be
7 changed files with 105 additions and 75 deletions

View File

@@ -85,7 +85,7 @@ time_embedding = TimeEmbedding(
non_linear_model = NonLinearRegression(
time_embedding.output_dim(inputDim),
len(quantiles),
len(quantiles) * 96,
hiddenSize=model_parameters["hidden_size"],
numLayers=model_parameters["num_layers"],
dropout=model_parameters["dropout"],
@@ -94,7 +94,7 @@ non_linear_model = NonLinearRegression(
# linear_model = LinearRegression(time_embedding.output_dim(inputDim), len(quantiles))
model = nn.Sequential(time_embedding, non_linear_model)
model.output_size = 1
model.output_size = 96
optimizer = torch.optim.Adam(model.parameters(), lr=model_parameters["learning_rate"])
### Policy Evaluator ###