Fixed policy evaluation for autoregressive

This commit is contained in:
2024-02-29 23:23:11 +01:00
parent fe1e388ffb
commit 34335cd9fe
10 changed files with 191 additions and 95 deletions

View File

@@ -25,18 +25,22 @@ class NrvDataset(Dataset):
self.sequence_length = sequence_length
self.predict_sequence_length = predict_sequence_length
self.samples_to_skip = self.skip_samples(dataframe=dataframe, full_day_skip=self.full_day_skip)
self.samples_to_skip = self.skip_samples(
dataframe=dataframe, full_day_skip=self.full_day_skip
)
total_indices = set(
range(len(dataframe) - self.sequence_length - self.predict_sequence_length)
)
self.valid_indices = sorted(list(total_indices - set(self.samples_to_skip)))
# full day indices
full_day_skipped_samples = self.skip_samples(dataframe=dataframe, full_day_skip=True)
full_day_total_indices = set(
range(len(dataframe) - self.sequence_length - self.predict_sequence_length)
full_day_skipped_samples = self.skip_samples(
dataframe=dataframe, full_day_skip=True
)
full_day_total_indices = set(range(len(dataframe) - self.sequence_length - 96))
self.full_day_valid_indices = sorted(
list(full_day_total_indices - set(full_day_skipped_samples))
)
self.full_day_valid_indices = sorted(list(full_day_total_indices - set(full_day_skipped_samples)))
self.history_features = []
if self.data_config.LOAD_HISTORY:
@@ -74,7 +78,7 @@ class NrvDataset(Dataset):
self.time_feature = torch.tensor(time_feature).float().reshape(-1)
else:
self.time_feature = None
self.nrv = torch.tensor(dataframe["nrv"].values).float().reshape(-1)
self.datetime = dataframe["datetime"]
@@ -84,12 +88,7 @@ class NrvDataset(Dataset):
nan_rows = dataframe[dataframe.isnull().any(axis=1)]
nan_indices = nan_rows.index
skip_indices = [
list(
range(
idx - self.sequence_length - 96, idx + 1
)
)
for idx in nan_indices
list(range(idx - self.sequence_length - 96, idx + 1)) for idx in nan_indices
]
skip_indices = [item for sublist in skip_indices for item in sublist]
@@ -106,10 +105,12 @@ class NrvDataset(Dataset):
skip_indices = list(set(skip_indices))
return skip_indices
def preprocess_data(self, dataframe):
    """Convert the configured feature columns of ``dataframe`` to float tensors.

    Args:
        dataframe: Table indexed with ``self.history_features`` and
            ``self.forecast_features`` as column selections.

    Returns:
        A ``(history, forecast)`` tuple of ``torch.float32`` tensors holding
        the values of the selected columns.
    """
    return (
        torch.tensor(dataframe[self.history_features].values).float(),
        torch.tensor(dataframe[self.forecast_features].values).float(),
    )
def __len__(self):
    """Return the number of entries in ``self.valid_indices``."""
    return len(self.valid_indices)
@@ -117,21 +118,38 @@ class NrvDataset(Dataset):
def _get_all_data(self, idx: int):
    """Return the history and forecast row windows that start at row ``idx``.

    Args:
        idx: Positional offset into ``self.dataframe`` of the first history row.

    Returns:
        ``(history_df, forecast_df)``: the ``self.sequence_length`` rows
        starting at ``idx``, followed by the next
        ``self.predict_sequence_length`` rows. ``iloc`` slicing clips at the
        end of the frame, so windows near the end may be shorter.
    """
    # Name the two window boundaries once instead of repeating the sums.
    history_end = idx + self.sequence_length
    forecast_end = history_end + self.predict_sequence_length

    history_df = self.dataframe.iloc[idx:history_end]
    forecast_df = self.dataframe.iloc[history_end:forecast_end]
    return history_df, forecast_df
def __getitem__(self, idx):
actual_idx = self.valid_indices[idx]
try:
actual_idx = self.valid_indices[idx]
except IndexError:
print(f"Index {idx} not in valid indices.")
raise
# get nrv history features
nrv_features = self.nrv[actual_idx : actual_idx + self.sequence_length]
history_features = self.history_features[actual_idx : actual_idx + self.sequence_length, :]
forecast_features = self.forecast_features[actual_idx + self.sequence_length : actual_idx + self.sequence_length + self.predict_sequence_length, :]
history_features = self.history_features[
actual_idx : actual_idx + self.sequence_length, :
]
forecast_features = self.forecast_features[
actual_idx
+ self.sequence_length : actual_idx
+ self.sequence_length
+ self.predict_sequence_length,
:,
]
if self.time_feature is not None:
time_features = self.time_feature[actual_idx : actual_idx + self.sequence_length]
time_features = self.time_feature[
actual_idx : actual_idx + self.sequence_length
]
else:
time_features = None
@@ -154,7 +172,9 @@ class NrvDataset(Dataset):
all_features_list = [nrv_features.unsqueeze(1)]
if self.forecast_features.numel() > 0:
history_forecast_features = self.forecast_features[actual_idx + 1 : actual_idx + self.sequence_length + 1, :]
history_forecast_features = self.forecast_features[
actual_idx + 1 : actual_idx + self.sequence_length + 1, :
]
all_features_list.append(history_forecast_features)
if time_features is not None:
@@ -163,7 +183,12 @@ class NrvDataset(Dataset):
all_features = torch.cat(all_features_list, dim=1)
# Target sequence, flattened if necessary
nrv_target = self.nrv[actual_idx + self.sequence_length : actual_idx + self.sequence_length + self.predict_sequence_length]
nrv_target = self.nrv[
actual_idx
+ self.sequence_length : actual_idx
+ self.sequence_length
+ self.predict_sequence_length
]
# check if nan values are present
if torch.isnan(all_features).any():
@@ -188,7 +213,6 @@ class NrvDataset(Dataset):
return all_features, nrv_target
def get_batch(self, idx: list):
features = []
targets = []
@@ -216,8 +240,8 @@ class NrvDataset(Dataset):
# check that the date is present in the dataset's datetime column
if date not in self.datetime.dt.date.unique():
raise ValueError(f"Date {date} not in dataset.")
idx = self.datetime[self.datetime.dt.date == date].index[0]
valid_idx = self.valid_indices.index(idx)
return valid_idx
return valid_idx