other changes

This commit is contained in:
Victor Mylle
2024-01-15 12:31:56 +00:00
parent 5f2418a205
commit 67cc6d4bb9
7 changed files with 855 additions and 482 deletions

View File

@@ -0,0 +1,53 @@
import argparse
from clearml import Task, Model
from src.data import DataProcessor, DataConfig
import torch
# argparse to parse task id and model type
parser = argparse.ArgumentParser()
parser.add_argument('--task_id', type=int, default=None)
parser.add_argument('--model_type', type=str, default=None)
args = parser.parse_args()
assert args.task_id is not None, "Please specify task id"
assert args.model_type is not None, "Please specify model type"
def load_model(task_id: str):
"""
Load model from task id
"""
task = Task.get_task(task_id=task_id)
configuration = task.get_parameters_as_dict()
data_features = configuration['data_features']
### Data Config ###
data_config = DataConfig()
for key, value in data_features.items():
setattr(data_config, key, bool(value))
data_config.PV_FORECAST = False
data_config.PV_HISTORY = False
data_config.QUARTER = False
data_config.DAY_OF_WEEK = False
### Data Processor ###
data_processor = DataProcessor(data_config, path="../../", lstm=False)
data_processor.set_batch_size(8192)
data_processor.set_full_day_skip(True)
### Model ###
output_model_id = task.output_models_id["checkpoint"]
clearml_model = Model(model_id=output_model_id)
filename = clearml_model.get_weights()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(filename)
model.to(device)
model.eval()
_, test_loader = data_processor.get_dataloaders(
predict_sequence_length=96
)
return configuration, model, test_loader