other changes
This commit is contained in:
53
src/policies/policy_executer.py
Normal file
53
src/policies/policy_executer.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import argparse
|
||||
from clearml import Task, Model
|
||||
from src.data import DataProcessor, DataConfig
|
||||
import torch
|
||||
|
||||
# argparse to parse task id and model type
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task_id', type=int, default=None)
|
||||
parser.add_argument('--model_type', type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.task_id is not None, "Please specify task id"
|
||||
assert args.model_type is not None, "Please specify model type"
|
||||
|
||||
def load_model(task_id: str):
|
||||
"""
|
||||
Load model from task id
|
||||
"""
|
||||
task = Task.get_task(task_id=task_id)
|
||||
|
||||
configuration = task.get_parameters_as_dict()
|
||||
data_features = configuration['data_features']
|
||||
|
||||
### Data Config ###
|
||||
data_config = DataConfig()
|
||||
for key, value in data_features.items():
|
||||
setattr(data_config, key, bool(value))
|
||||
data_config.PV_FORECAST = False
|
||||
data_config.PV_HISTORY = False
|
||||
data_config.QUARTER = False
|
||||
data_config.DAY_OF_WEEK = False
|
||||
|
||||
### Data Processor ###
|
||||
data_processor = DataProcessor(data_config, path="../../", lstm=False)
|
||||
data_processor.set_batch_size(8192)
|
||||
data_processor.set_full_day_skip(True)
|
||||
|
||||
### Model ###
|
||||
output_model_id = task.output_models_id["checkpoint"]
|
||||
clearml_model = Model(model_id=output_model_id)
|
||||
filename = clearml_model.get_weights()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = torch.load(filename)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
_, test_loader = data_processor.get_dataloaders(
|
||||
predict_sequence_length=96
|
||||
)
|
||||
|
||||
return configuration, model, test_loader
|
||||
Reference in New Issue
Block a user