Sped up sampling 20x

This commit is contained in:
Victor Mylle
2023-11-25 18:09:42 +00:00
parent 5de3f64a1a
commit 300f268286
10 changed files with 498 additions and 238 deletions

View File

@@ -62,7 +62,7 @@ class NrvDataset(Dataset):
# get indices of all 00:15 timestamps
if self.full_day_skip:
start_of_day_indices = self.dataframe[
self.dataframe["datetime"].dt.time == pd.Timestamp("00:15:00").time()
self.dataframe["datetime"].dt.time != pd.Timestamp("00:15:00").time()
].index
skip_indices.extend(start_of_day_indices)
skip_indices = list(set(skip_indices))
@@ -147,7 +147,7 @@ class NrvDataset(Dataset):
print(f"Actual index: {actual_idx}")
raise ValueError("There are nan values in the features.")
return all_features, nrv_target
return all_features, nrv_target, idx
def random_day_autoregressive(self, idx: int):
idx = self.valid_indices[idx]
@@ -205,3 +205,26 @@ class NrvDataset(Dataset):
all_features = torch.cat(features, dim=0)
return all_features, target
def get_batch(self, idx: list):
    """Collate the samples at the given indices into a single batch.

    Args:
        idx: Dataset indices to fetch.

    Returns:
        Tuple ``(features, targets)`` of tensors stacked along a new
        leading batch dimension.
    """
    # Each sample is (features, target, original_index); the trailing
    # index is only returned for debugging and is dropped here.
    # Use the indexing protocol (self[i]) rather than calling
    # __getitem__ directly.
    samples = [self[i] for i in idx]
    features = torch.stack([f for f, _, _ in samples])
    targets = torch.stack([t for _, t, _ in samples])
    return features, targets
def get_batch_autoregressive(self, idx: list):
    """Collate autoregressive samples at the given indices into a batch.

    Args:
        idx: Dataset indices to fetch via ``random_day_autoregressive``.

    Returns:
        Tuple ``(features, targets)``. ``features`` is ``None`` when any
        sample yielded no features (the caller is expected to handle
        that sentinel); otherwise both are tensors stacked along a new
        leading batch dimension.
    """
    features = []
    targets = []
    for i in idx:
        f, t = self.random_day_autoregressive(i)
        features.append(f)
        targets.append(t)
    # Use an identity check: `None in features` compares with `==`,
    # and `tensor == None` can yield an elementwise tensor whose
    # truthiness raises "Boolean value of Tensor is ambiguous".
    if any(f is None for f in features):
        return None, torch.stack(targets)
    return torch.stack(features), torch.stack(targets)

View File

@@ -167,7 +167,10 @@ class DataProcessor:
)
def get_train_dataloader(
self, transform: bool = True, predict_sequence_length: int = 96
self,
transform: bool = True,
predict_sequence_length: int = 96,
shuffle: bool = True,
):
train_df = self.all_features.copy()
@@ -194,7 +197,7 @@ class DataProcessor:
full_day_skip=self.full_day_skip,
predict_sequence_length=predict_sequence_length,
)
return self.get_dataloader(train_dataset)
return self.get_dataloader(train_dataset, shuffle=shuffle)
def get_test_dataloader(
self, transform: bool = True, predict_sequence_length: int = 96
@@ -262,5 +265,5 @@ class DataProcessor:
data_loader = self.get_train_dataloader(
predict_sequence_length=self.output_size
)
input, _ = next(iter(data_loader))
input, _, _ = next(iter(data_loader))
return input.shape[-1]