Added functions to compare policies
This commit is contained in:
@@ -207,4 +207,11 @@ class NrvDataset(Dataset):
|
||||
return torch.stack(features), torch.stack(targets)
|
||||
|
||||
def get_idx_for_date(self, date: datetime.date):
|
||||
return self.datetime[self.datetime.dt.date == date].index[0]
|
||||
# check if the date is in the valid indices
|
||||
if date not in self.datetime.dt.date.unique():
|
||||
raise ValueError(f"Date {date} not in dataset.")
|
||||
|
||||
idx = self.datetime[self.datetime.dt.date == date].index[0]
|
||||
|
||||
valid_idx = self.valid_indices.index(idx)
|
||||
return valid_idx
|
||||
Reference in New Issue
Block a user