Added functions to compare policies

This commit is contained in:
Victor Mylle
2023-12-19 16:22:13 +00:00
parent fee948cc09
commit d0fa815b68
5 changed files with 715 additions and 88 deletions

View File

@@ -207,4 +207,11 @@ class NrvDataset(Dataset):
return torch.stack(features), torch.stack(targets)
def get_idx_for_date(self, date: datetime.date):
return self.datetime[self.datetime.dt.date == date].index[0]
# check if the date is in the valid indices
if date not in self.datetime.dt.date.unique():
raise ValueError(f"Date {date} not in dataset.")
idx = self.datetime[self.datetime.dt.date == date].index[0]
valid_idx = self.valid_indices.index(idx)
return valid_idx