Sped up sampling 20x
This commit is contained in:
@@ -62,7 +62,7 @@ class NrvDataset(Dataset):
|
||||
# get indices of all 00:15 timestamps
|
||||
if self.full_day_skip:
|
||||
start_of_day_indices = self.dataframe[
|
||||
self.dataframe["datetime"].dt.time == pd.Timestamp("00:15:00").time()
|
||||
self.dataframe["datetime"].dt.time != pd.Timestamp("00:15:00").time()
|
||||
].index
|
||||
skip_indices.extend(start_of_day_indices)
|
||||
skip_indices = list(set(skip_indices))
|
||||
@@ -147,7 +147,7 @@ class NrvDataset(Dataset):
|
||||
print(f"Actual index: {actual_idx}")
|
||||
raise ValueError("There are nan values in the features.")
|
||||
|
||||
return all_features, nrv_target
|
||||
return all_features, nrv_target, idx
|
||||
|
||||
def random_day_autoregressive(self, idx: int):
|
||||
idx = self.valid_indices[idx]
|
||||
@@ -205,3 +205,26 @@ class NrvDataset(Dataset):
|
||||
|
||||
all_features = torch.cat(features, dim=0)
|
||||
return all_features, target
|
||||
|
||||
def get_batch(self, idx: list):
    """Build one stacked batch from the samples at the given dataset indices.

    Each sample is fetched by indexing the dataset, which yields a
    ``(features, target, index)`` triple; the returned index is not needed
    here and is discarded.

    Args:
        idx: Dataset indices of the samples to include, in batch order.

    Returns:
        A ``(features, targets)`` pair where each element is the
        ``torch.stack`` of the per-sample tensors along a new leading
        dimension.

    Raises:
        RuntimeError: Raised by ``torch.stack`` when ``idx`` is empty or the
            per-sample tensors do not share a shape.
    """
    # Fetch every sample once, then split the triples into parallel lists.
    samples = [self[i] for i in idx]
    features = [sample_features for sample_features, _, _ in samples]
    targets = [sample_target for _, sample_target, _ in samples]

    return torch.stack(features), torch.stack(targets)
|
||||
|
||||
def get_batch_autoregressive(self, idx: list):
    """Build one stacked batch of autoregressive samples.

    Each sample comes from ``random_day_autoregressive``, which yields a
    ``(features, target)`` pair.

    Args:
        idx: Indices passed one by one to ``random_day_autoregressive``,
            in batch order.

    Returns:
        A ``(features, targets)`` pair. ``targets`` is always the
        ``torch.stack`` of the per-sample targets. ``features`` is the
        stacked per-sample features, or ``None`` when at least one sample
        has no features.

    Raises:
        RuntimeError: Raised by ``torch.stack`` when ``idx`` is empty or the
            per-sample tensors do not share a shape.
    """
    pairs = [self.random_day_autoregressive(i) for i in idx]
    features = [sample_features for sample_features, _ in pairs]
    targets = [sample_target for _, sample_target in pairs]

    # Use an identity check rather than ``None in features``: list
    # membership falls back to ``==``, which would invoke the tensor
    # equality operator against None for every element.
    if any(sample_features is None for sample_features in features):
        return None, torch.stack(targets)

    return torch.stack(features), torch.stack(targets)
|
||||
|
||||
@@ -167,7 +167,10 @@ class DataProcessor:
|
||||
)
|
||||
|
||||
def get_train_dataloader(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
self,
|
||||
transform: bool = True,
|
||||
predict_sequence_length: int = 96,
|
||||
shuffle: bool = True,
|
||||
):
|
||||
train_df = self.all_features.copy()
|
||||
|
||||
@@ -194,7 +197,7 @@ class DataProcessor:
|
||||
full_day_skip=self.full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
)
|
||||
return self.get_dataloader(train_dataset)
|
||||
return self.get_dataloader(train_dataset, shuffle=shuffle)
|
||||
|
||||
def get_test_dataloader(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
@@ -262,5 +265,5 @@ class DataProcessor:
|
||||
data_loader = self.get_train_dataloader(
|
||||
predict_sequence_length=self.output_size
|
||||
)
|
||||
input, _ = next(iter(data_loader))
|
||||
input, _, _ = next(iter(data_loader))
|
||||
return input.shape[-1]
|
||||
|
||||
Reference in New Issue
Block a user