Fixed diffusion confidence interval plot

This commit is contained in:
Victor Mylle
2024-02-18 16:01:18 +01:00
parent 7bd0476085
commit bd250a664b
3 changed files with 87 additions and 67 deletions

View File

@@ -38,7 +38,7 @@ data_config.NOMINAL_NET_POSITION = True
data_config = task.connect(data_config, name="data_features")
data_processor = DataProcessor(data_config, path="", lstm=False)
data_processor.set_batch_size(128)
data_processor.set_batch_size(64)
data_processor.set_full_day_skip(True)
inputDim = data_processor.get_input_size()
@@ -47,15 +47,15 @@ print("Input dim: ", inputDim)
model_parameters = {
"epochs": 5000,
"learning_rate": 0.0001,
"hidden_sizes": [512, 512, 512],
"time_dim": 64,
"hidden_sizes": [128, 128],
"time_dim": 8,
}
model_parameters = task.connect(model_parameters, name="model_parameters")
#### Model ####
# model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"])
model = GRUDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[2], time_dim=model_parameters["time_dim"], gru_hidden_size=256)
model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"])
# model = GRUDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[2], time_dim=model_parameters["time_dim"], gru_hidden_size=128)
print("Starting training ...")