Sped up sampling 20x
This commit is contained in:
@@ -167,7 +167,10 @@ class DataProcessor:
|
||||
)
|
||||
|
||||
def get_train_dataloader(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
self,
|
||||
transform: bool = True,
|
||||
predict_sequence_length: int = 96,
|
||||
shuffle: bool = True,
|
||||
):
|
||||
train_df = self.all_features.copy()
|
||||
|
||||
@@ -194,7 +197,7 @@ class DataProcessor:
|
||||
full_day_skip=self.full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
)
|
||||
return self.get_dataloader(train_dataset)
|
||||
return self.get_dataloader(train_dataset, shuffle=shuffle)
|
||||
|
||||
def get_test_dataloader(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
@@ -262,5 +265,5 @@ class DataProcessor:
|
||||
data_loader = self.get_train_dataloader(
|
||||
predict_sequence_length=self.output_size
|
||||
)
|
||||
input, _ = next(iter(data_loader))
|
||||
input, _, _ = next(iter(data_loader))
|
||||
return input.shape[-1]
|
||||
|
||||
Reference in New Issue
Block a user