Fixed policy evaluation for autoregressive
This commit is contained in:
@@ -25,18 +25,22 @@ class NrvDataset(Dataset):
|
||||
self.sequence_length = sequence_length
|
||||
self.predict_sequence_length = predict_sequence_length
|
||||
|
||||
self.samples_to_skip = self.skip_samples(dataframe=dataframe, full_day_skip=self.full_day_skip)
|
||||
self.samples_to_skip = self.skip_samples(
|
||||
dataframe=dataframe, full_day_skip=self.full_day_skip
|
||||
)
|
||||
total_indices = set(
|
||||
range(len(dataframe) - self.sequence_length - self.predict_sequence_length)
|
||||
)
|
||||
self.valid_indices = sorted(list(total_indices - set(self.samples_to_skip)))
|
||||
|
||||
# full day indices
|
||||
full_day_skipped_samples = self.skip_samples(dataframe=dataframe, full_day_skip=True)
|
||||
full_day_total_indices = set(
|
||||
range(len(dataframe) - self.sequence_length - self.predict_sequence_length)
|
||||
full_day_skipped_samples = self.skip_samples(
|
||||
dataframe=dataframe, full_day_skip=True
|
||||
)
|
||||
|
||||
full_day_total_indices = set(range(len(dataframe) - self.sequence_length - 96))
|
||||
self.full_day_valid_indices = sorted(
|
||||
list(full_day_total_indices - set(full_day_skipped_samples))
|
||||
)
|
||||
self.full_day_valid_indices = sorted(list(full_day_total_indices - set(full_day_skipped_samples)))
|
||||
|
||||
self.history_features = []
|
||||
if self.data_config.LOAD_HISTORY:
|
||||
@@ -74,7 +78,7 @@ class NrvDataset(Dataset):
|
||||
self.time_feature = torch.tensor(time_feature).float().reshape(-1)
|
||||
else:
|
||||
self.time_feature = None
|
||||
|
||||
|
||||
self.nrv = torch.tensor(dataframe["nrv"].values).float().reshape(-1)
|
||||
self.datetime = dataframe["datetime"]
|
||||
|
||||
@@ -84,12 +88,7 @@ class NrvDataset(Dataset):
|
||||
nan_rows = dataframe[dataframe.isnull().any(axis=1)]
|
||||
nan_indices = nan_rows.index
|
||||
skip_indices = [
|
||||
list(
|
||||
range(
|
||||
idx - self.sequence_length - 96, idx + 1
|
||||
)
|
||||
)
|
||||
for idx in nan_indices
|
||||
list(range(idx - self.sequence_length - 96, idx + 1)) for idx in nan_indices
|
||||
]
|
||||
|
||||
skip_indices = [item for sublist in skip_indices for item in sublist]
|
||||
@@ -106,10 +105,12 @@ class NrvDataset(Dataset):
|
||||
skip_indices = list(set(skip_indices))
|
||||
|
||||
return skip_indices
|
||||
|
||||
def preprocess_data(self, dataframe):
|
||||
return torch.tensor(dataframe[self.history_features].values).float(), torch.tensor(dataframe[self.forecast_features].values).float()
|
||||
|
||||
def preprocess_data(self, dataframe):
|
||||
return (
|
||||
torch.tensor(dataframe[self.history_features].values).float(),
|
||||
torch.tensor(dataframe[self.forecast_features].values).float(),
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.valid_indices)
|
||||
@@ -117,21 +118,38 @@ class NrvDataset(Dataset):
|
||||
def _get_all_data(self, idx: int):
|
||||
history_df = self.dataframe.iloc[idx : idx + self.sequence_length]
|
||||
forecast_df = self.dataframe.iloc[
|
||||
idx + self.sequence_length : idx + self.sequence_length + self.predict_sequence_length
|
||||
idx
|
||||
+ self.sequence_length : idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
return history_df, forecast_df
|
||||
|
||||
def __getitem__(self, idx):
|
||||
actual_idx = self.valid_indices[idx]
|
||||
|
||||
try:
|
||||
actual_idx = self.valid_indices[idx]
|
||||
except IndexError:
|
||||
print(f"Index {idx} not in valid indices.")
|
||||
raise
|
||||
|
||||
# get nrv history features
|
||||
nrv_features = self.nrv[actual_idx : actual_idx + self.sequence_length]
|
||||
|
||||
history_features = self.history_features[actual_idx : actual_idx + self.sequence_length, :]
|
||||
forecast_features = self.forecast_features[actual_idx + self.sequence_length : actual_idx + self.sequence_length + self.predict_sequence_length, :]
|
||||
history_features = self.history_features[
|
||||
actual_idx : actual_idx + self.sequence_length, :
|
||||
]
|
||||
forecast_features = self.forecast_features[
|
||||
actual_idx
|
||||
+ self.sequence_length : actual_idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length,
|
||||
:,
|
||||
]
|
||||
|
||||
if self.time_feature is not None:
|
||||
time_features = self.time_feature[actual_idx : actual_idx + self.sequence_length]
|
||||
time_features = self.time_feature[
|
||||
actual_idx : actual_idx + self.sequence_length
|
||||
]
|
||||
else:
|
||||
time_features = None
|
||||
|
||||
@@ -154,7 +172,9 @@ class NrvDataset(Dataset):
|
||||
all_features_list = [nrv_features.unsqueeze(1)]
|
||||
|
||||
if self.forecast_features.numel() > 0:
|
||||
history_forecast_features = self.forecast_features[actual_idx + 1 : actual_idx + self.sequence_length + 1, :]
|
||||
history_forecast_features = self.forecast_features[
|
||||
actual_idx + 1 : actual_idx + self.sequence_length + 1, :
|
||||
]
|
||||
all_features_list.append(history_forecast_features)
|
||||
|
||||
if time_features is not None:
|
||||
@@ -163,7 +183,12 @@ class NrvDataset(Dataset):
|
||||
all_features = torch.cat(all_features_list, dim=1)
|
||||
|
||||
# Target sequence, flattened if necessary
|
||||
nrv_target = self.nrv[actual_idx + self.sequence_length : actual_idx + self.sequence_length + self.predict_sequence_length]
|
||||
nrv_target = self.nrv[
|
||||
actual_idx
|
||||
+ self.sequence_length : actual_idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
|
||||
# check if nan values are present
|
||||
if torch.isnan(all_features).any():
|
||||
@@ -188,7 +213,6 @@ class NrvDataset(Dataset):
|
||||
|
||||
return all_features, nrv_target
|
||||
|
||||
|
||||
def get_batch(self, idx: list):
|
||||
features = []
|
||||
targets = []
|
||||
@@ -216,8 +240,8 @@ class NrvDataset(Dataset):
|
||||
# check if the date is in the valid indices
|
||||
if date not in self.datetime.dt.date.unique():
|
||||
raise ValueError(f"Date {date} not in dataset.")
|
||||
|
||||
|
||||
idx = self.datetime[self.datetime.dt.date == date].index[0]
|
||||
|
||||
valid_idx = self.valid_indices.index(idx)
|
||||
return valid_idx
|
||||
return valid_idx
|
||||
|
||||
Reference in New Issue
Block a user