Lot of changes
This commit is contained in:
@@ -2,8 +2,16 @@ import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class NrvDataset(Dataset):
|
||||
def __init__(self, dataframe, data_config, full_day_skip: bool = False, sequence_length=96, predict_sequence_length=96):
|
||||
def __init__(
|
||||
self,
|
||||
dataframe,
|
||||
data_config,
|
||||
full_day_skip: bool = False,
|
||||
sequence_length=96,
|
||||
predict_sequence_length=96,
|
||||
):
|
||||
self.data_config = data_config
|
||||
self.dataframe = dataframe
|
||||
self.full_day_skip = full_day_skip
|
||||
@@ -11,26 +19,40 @@ class NrvDataset(Dataset):
|
||||
# reset dataframe index
|
||||
self.dataframe.reset_index(drop=True, inplace=True)
|
||||
|
||||
self.nrv = torch.tensor(dataframe['nrv'].to_numpy(), dtype=torch.float32)
|
||||
self.load_forecast = torch.tensor(dataframe['load_forecast'].to_numpy(), dtype=torch.float32)
|
||||
self.total_load = torch.tensor(dataframe['total_load'].to_numpy(), dtype=torch.float32)
|
||||
self.pv_gen_forecast = torch.tensor(dataframe['pv_forecast'].to_numpy(), dtype=torch.float32)
|
||||
self.wind_gen_forecast = torch.tensor(dataframe['wind_forecast'].to_numpy(), dtype=torch.float32)
|
||||
self.nrv = torch.tensor(dataframe["nrv"].to_numpy(), dtype=torch.float32)
|
||||
self.load_forecast = torch.tensor(
|
||||
dataframe["load_forecast"].to_numpy(), dtype=torch.float32
|
||||
)
|
||||
self.total_load = torch.tensor(
|
||||
dataframe["total_load"].to_numpy(), dtype=torch.float32
|
||||
)
|
||||
self.pv_gen_forecast = torch.tensor(
|
||||
dataframe["pv_forecast"].to_numpy(), dtype=torch.float32
|
||||
)
|
||||
self.wind_gen_forecast = torch.tensor(
|
||||
dataframe["wind_forecast"].to_numpy(), dtype=torch.float32
|
||||
)
|
||||
|
||||
self.sequence_length = sequence_length
|
||||
self.predict_sequence_length = predict_sequence_length
|
||||
|
||||
self.samples_to_skip = self.skip_samples()
|
||||
total_indices = set(range(len(self.nrv) - self.sequence_length - self.predict_sequence_length))
|
||||
total_indices = set(
|
||||
range(len(self.nrv) - self.sequence_length - self.predict_sequence_length)
|
||||
)
|
||||
self.valid_indices = sorted(list(total_indices - set(self.samples_to_skip)))
|
||||
|
||||
### TODO: Option to only use full day samples ###
|
||||
### skip all samples between is the easiest way I think (not most efficient though) ###
|
||||
|
||||
def skip_samples(self):
|
||||
nan_rows = self.dataframe[self.dataframe.isnull().any(axis=1)]
|
||||
nan_indices = nan_rows.index
|
||||
skip_indices = [list(range(idx-self.sequence_length-self.predict_sequence_length, idx+1)) for idx in nan_indices]
|
||||
skip_indices = [
|
||||
list(
|
||||
range(
|
||||
idx - self.sequence_length - self.predict_sequence_length, idx + 1
|
||||
)
|
||||
)
|
||||
for idx in nan_indices
|
||||
]
|
||||
|
||||
skip_indices = [item for sublist in skip_indices for item in sublist]
|
||||
skip_indices = list(set(skip_indices))
|
||||
@@ -39,7 +61,9 @@ class NrvDataset(Dataset):
|
||||
# add indices that are not the start of a day (00:15) to the skip indices (use datetime column)
|
||||
# get indices of all 00:15 timestamps
|
||||
if self.full_day_skip:
|
||||
start_of_day_indices = self.dataframe[self.dataframe['datetime'].dt.time == pd.Timestamp('00:15:00').time()].index
|
||||
start_of_day_indices = self.dataframe[
|
||||
self.dataframe["datetime"].dt.time == pd.Timestamp("00:15:00").time()
|
||||
].index
|
||||
skip_indices.extend(start_of_day_indices)
|
||||
skip_indices = list(set(skip_indices))
|
||||
|
||||
@@ -47,47 +71,75 @@ class NrvDataset(Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.valid_indices)
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
actual_idx = self.valid_indices[idx]
|
||||
features = []
|
||||
|
||||
if self.data_config.NRV_HISTORY:
|
||||
nrv = self.nrv[actual_idx:actual_idx+self.sequence_length]
|
||||
nrv = self.nrv[actual_idx : actual_idx + self.sequence_length]
|
||||
features.append(nrv.view(-1))
|
||||
|
||||
if self.data_config.LOAD_HISTORY:
|
||||
load_history = self.total_load[actual_idx:actual_idx+self.sequence_length]
|
||||
load_history = self.total_load[
|
||||
actual_idx : actual_idx + self.sequence_length
|
||||
]
|
||||
features.append(load_history.view(-1))
|
||||
|
||||
if self.data_config.PV_HISTORY:
|
||||
pv_history = self.pv_gen_forecast[actual_idx:actual_idx+self.sequence_length]
|
||||
pv_history = self.pv_gen_forecast[
|
||||
actual_idx : actual_idx + self.sequence_length
|
||||
]
|
||||
features.append(pv_history.view(-1))
|
||||
|
||||
if self.data_config.WIND_HISTORY:
|
||||
wind_history = self.wind_gen_forecast[actual_idx:actual_idx+self.sequence_length]
|
||||
wind_history = self.wind_gen_forecast[
|
||||
actual_idx : actual_idx + self.sequence_length
|
||||
]
|
||||
features.append(wind_history.view(-1))
|
||||
|
||||
if self.data_config.LOAD_FORECAST:
|
||||
load_forecast = self.load_forecast[actual_idx+self.sequence_length:actual_idx+self.sequence_length+self.predict_sequence_length]
|
||||
load_forecast = self.load_forecast[
|
||||
actual_idx
|
||||
+ self.sequence_length : actual_idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
features.append(load_forecast.view(-1))
|
||||
|
||||
if self.data_config.PV_FORECAST:
|
||||
pv_forecast = self.pv_gen_forecast[actual_idx+self.sequence_length:actual_idx+self.sequence_length+self.predict_sequence_length]
|
||||
pv_forecast = self.pv_gen_forecast[
|
||||
actual_idx
|
||||
+ self.sequence_length : actual_idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
features.append(pv_forecast.view(-1))
|
||||
|
||||
if self.data_config.WIND_FORECAST:
|
||||
wind_forecast = self.wind_gen_forecast[actual_idx+self.sequence_length:actual_idx+self.sequence_length+self.predict_sequence_length]
|
||||
wind_forecast = self.wind_gen_forecast[
|
||||
actual_idx
|
||||
+ self.sequence_length : actual_idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
features.append(wind_forecast.view(-1))
|
||||
|
||||
if not features:
|
||||
raise ValueError("No features are configured to be included in the dataset.")
|
||||
raise ValueError(
|
||||
"No features are configured to be included in the dataset."
|
||||
)
|
||||
|
||||
# Concatenate along dimension 0 to create a one-dimensional feature vector
|
||||
all_features = torch.cat(features, dim=0)
|
||||
|
||||
|
||||
# Target sequence, flattened if necessary
|
||||
nrv_target = self.nrv[actual_idx+self.sequence_length:actual_idx+self.sequence_length+self.predict_sequence_length].view(-1)
|
||||
nrv_target = self.nrv[
|
||||
actual_idx
|
||||
+ self.sequence_length : actual_idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
].view(-1)
|
||||
|
||||
# check if nan values are present
|
||||
if torch.isnan(all_features).any():
|
||||
@@ -103,35 +155,53 @@ class NrvDataset(Dataset):
|
||||
|
||||
# we already have the NRV history with the newly predicted values, so we don't need to include the last 96 values
|
||||
if self.data_config.LOAD_HISTORY:
|
||||
load_history = self.total_load[idx:idx+self.sequence_length]
|
||||
load_history = self.total_load[idx : idx + self.sequence_length]
|
||||
features.append(load_history.view(-1))
|
||||
|
||||
if self.data_config.PV_HISTORY:
|
||||
pv_history = self.pv_gen_forecast[idx:idx+self.sequence_length]
|
||||
pv_history = self.pv_gen_forecast[idx : idx + self.sequence_length]
|
||||
features.append(pv_history.view(-1))
|
||||
|
||||
if self.data_config.WIND_HISTORY:
|
||||
wind_history = self.wind_gen_forecast[idx:idx+self.sequence_length]
|
||||
wind_history = self.wind_gen_forecast[idx : idx + self.sequence_length]
|
||||
features.append(wind_history.view(-1))
|
||||
|
||||
if self.data_config.LOAD_FORECAST:
|
||||
load_forecast = self.load_forecast[idx+self.sequence_length:idx+self.sequence_length+self.predict_sequence_length]
|
||||
load_forecast = self.load_forecast[
|
||||
idx
|
||||
+ self.sequence_length : idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
features.append(load_forecast.view(-1))
|
||||
|
||||
if self.data_config.PV_FORECAST:
|
||||
pv_forecast = self.pv_gen_forecast[idx+self.sequence_length:idx+self.sequence_length+self.predict_sequence_length]
|
||||
pv_forecast = self.pv_gen_forecast[
|
||||
idx
|
||||
+ self.sequence_length : idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
features.append(pv_forecast.view(-1))
|
||||
|
||||
if self.data_config.WIND_FORECAST:
|
||||
wind_forecast = self.wind_gen_forecast[idx+self.sequence_length:idx+self.sequence_length+self.predict_sequence_length]
|
||||
wind_forecast = self.wind_gen_forecast[
|
||||
idx
|
||||
+ self.sequence_length : idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
features.append(wind_forecast.view(-1))
|
||||
|
||||
|
||||
|
||||
target = self.nrv[idx+self.sequence_length:idx+self.sequence_length+self.predict_sequence_length]
|
||||
target = self.nrv[
|
||||
idx
|
||||
+ self.sequence_length : idx
|
||||
+ self.sequence_length
|
||||
+ self.predict_sequence_length
|
||||
]
|
||||
|
||||
if len(features) == 0:
|
||||
return None, target
|
||||
|
||||
all_features = torch.cat(features, dim=0)
|
||||
return all_features, target
|
||||
return all_features, target
|
||||
|
||||
@@ -12,6 +12,7 @@ forecast_data_path = "../../data/load_forecast.csv"
|
||||
pv_forecast_data_path = "../../data/pv_gen_forecast.csv"
|
||||
wind_forecast_data_path = "../../data/wind_gen_forecast.csv"
|
||||
|
||||
|
||||
class DataConfig:
|
||||
def __init__(self):
|
||||
self.NRV_HISTORY: bool = True
|
||||
@@ -28,11 +29,20 @@ class DataConfig:
|
||||
self.WIND_FORECAST: bool = False
|
||||
self.WIND_HISTORY: bool = False
|
||||
|
||||
### TIME ###
|
||||
self.YEAR: bool = False
|
||||
self.DAY: bool = False
|
||||
self.QUARTER: bool = False
|
||||
|
||||
|
||||
class DataProcessor:
|
||||
def __init__(self, data_config: DataConfig):
|
||||
self.batch_size = 2048
|
||||
|
||||
self.train_range = (-np.inf, datetime(year=2022, month=11, day=30, tzinfo=pytz.UTC))
|
||||
self.train_range = (
|
||||
-np.inf,
|
||||
datetime(year=2022, month=11, day=30, tzinfo=pytz.UTC),
|
||||
)
|
||||
self.test_range = (datetime(year=2023, month=1, day=1, tzinfo=pytz.UTC), np.inf)
|
||||
|
||||
self.update_range_str()
|
||||
@@ -42,9 +52,17 @@ class DataProcessor:
|
||||
self.pv_forecast = self.get_pv_forecast()
|
||||
self.wind_forecast = self.get_wind_forecast()
|
||||
|
||||
self.all_features = self.history_features.merge(self.future_features, on='datetime', how='left')
|
||||
self.all_features = self.all_features.merge(self.pv_forecast, on='datetime', how='left')
|
||||
self.all_features = self.all_features.merge(self.wind_forecast, on='datetime', how='left')
|
||||
self.all_features = self.history_features.merge(
|
||||
self.future_features, on="datetime", how="left"
|
||||
)
|
||||
self.all_features = self.all_features.merge(
|
||||
self.pv_forecast, on="datetime", how="left"
|
||||
)
|
||||
self.all_features = self.all_features.merge(
|
||||
self.wind_forecast, on="datetime", how="left"
|
||||
)
|
||||
|
||||
self.output_size = 96
|
||||
|
||||
self.data_config = data_config
|
||||
|
||||
@@ -59,6 +77,9 @@ class DataProcessor:
|
||||
def set_full_day_skip(self, full_day_skip: bool):
|
||||
self.full_day_skip = full_day_skip
|
||||
|
||||
def set_output_size(self, output_size: int):
|
||||
self.output_size = output_size
|
||||
|
||||
def set_train_range(self, train_range: tuple):
|
||||
self.train_range = train_range
|
||||
self.update_range_str()
|
||||
@@ -68,106 +89,178 @@ class DataProcessor:
|
||||
self.update_range_str()
|
||||
|
||||
def update_range_str(self):
|
||||
self.train_range_start = str(self.train_range[0]) if self.train_range[0] != -np.inf else "-inf"
|
||||
self.train_range_end = str(self.train_range[1]) if self.train_range[1] != np.inf else "inf"
|
||||
self.test_range_start = str(self.test_range[0]) if self.test_range[0] != -np.inf else "-inf"
|
||||
self.test_range_end = str(self.test_range[1]) if self.test_range[1] != np.inf else "inf"
|
||||
self.train_range_start = (
|
||||
str(self.train_range[0]) if self.train_range[0] != -np.inf else "-inf"
|
||||
)
|
||||
self.train_range_end = (
|
||||
str(self.train_range[1]) if self.train_range[1] != np.inf else "inf"
|
||||
)
|
||||
self.test_range_start = (
|
||||
str(self.test_range[0]) if self.test_range[0] != -np.inf else "-inf"
|
||||
)
|
||||
self.test_range_end = (
|
||||
str(self.test_range[1]) if self.test_range[1] != np.inf else "inf"
|
||||
)
|
||||
|
||||
def get_nrv_history(self):
|
||||
df = pd.read_csv(history_data_path, delimiter=';')
|
||||
df = df[['datetime', 'netregulationvolume']]
|
||||
df = df.rename(columns={'netregulationvolume': 'nrv'})
|
||||
df['datetime'] = pd.to_datetime(df['datetime'])
|
||||
counts = df['datetime'].dt.date.value_counts().sort_index()
|
||||
df = df[df['datetime'].dt.date.isin(counts[counts == 96].index)]
|
||||
df = pd.read_csv(history_data_path, delimiter=";")
|
||||
df = df[["datetime", "netregulationvolume"]]
|
||||
df = df.rename(columns={"netregulationvolume": "nrv"})
|
||||
df["datetime"] = pd.to_datetime(df["datetime"])
|
||||
counts = df["datetime"].dt.date.value_counts().sort_index()
|
||||
df = df[df["datetime"].dt.date.isin(counts[counts == 96].index)]
|
||||
|
||||
df.sort_values(by="datetime", inplace=True)
|
||||
return df
|
||||
|
||||
|
||||
def get_load_forecast(self):
|
||||
df = pd.read_csv(forecast_data_path, delimiter=';')
|
||||
df = df.rename(columns={'Day-ahead 6PM forecast': 'load_forecast', 'Datetime': 'datetime', 'Total Load': 'total_load'})
|
||||
df = df[['datetime', 'load_forecast', 'total_load']]
|
||||
|
||||
df['datetime'] = pd.to_datetime(df['datetime'], utc=True)
|
||||
df = pd.read_csv(forecast_data_path, delimiter=";")
|
||||
df = df.rename(
|
||||
columns={
|
||||
"Day-ahead 6PM forecast": "load_forecast",
|
||||
"Datetime": "datetime",
|
||||
"Total Load": "total_load",
|
||||
}
|
||||
)
|
||||
df = df[["datetime", "load_forecast", "total_load"]]
|
||||
|
||||
df["datetime"] = pd.to_datetime(df["datetime"], utc=True)
|
||||
df.sort_values(by="datetime", inplace=True)
|
||||
return df
|
||||
|
||||
|
||||
def get_pv_forecast(self):
|
||||
df = pd.read_csv(pv_forecast_data_path, delimiter=';')
|
||||
df = pd.read_csv(pv_forecast_data_path, delimiter=";")
|
||||
|
||||
df = df.rename(columns={'dayahead11hforecast': 'pv_forecast', 'Datetime': 'datetime'})
|
||||
df = df[['datetime', 'pv_forecast']]
|
||||
df = df.rename(
|
||||
columns={"dayahead11hforecast": "pv_forecast", "Datetime": "datetime"}
|
||||
)
|
||||
df = df[["datetime", "pv_forecast"]]
|
||||
|
||||
df = df.groupby('datetime').mean().reset_index()
|
||||
df['datetime'] = pd.to_datetime(df['datetime'], utc=True)
|
||||
df = df.groupby("datetime").mean().reset_index()
|
||||
df["datetime"] = pd.to_datetime(df["datetime"], utc=True)
|
||||
df.sort_values(by="datetime", inplace=True)
|
||||
return df
|
||||
|
||||
def get_wind_forecast(self):
|
||||
df = pd.read_csv(wind_forecast_data_path, delimiter=';')
|
||||
df = pd.read_csv(wind_forecast_data_path, delimiter=";")
|
||||
|
||||
df = df.rename(columns={'dayaheadforecast': 'wind_forecast', 'datetime': 'datetime'})
|
||||
df = df[['datetime', 'wind_forecast']]
|
||||
df = df.rename(
|
||||
columns={"dayaheadforecast": "wind_forecast", "datetime": "datetime"}
|
||||
)
|
||||
df = df[["datetime", "wind_forecast"]]
|
||||
|
||||
# remove nan rows
|
||||
df = df[~df['wind_forecast'].isnull()]
|
||||
df = df[~df["wind_forecast"].isnull()]
|
||||
|
||||
df = df.groupby('datetime').mean().reset_index()
|
||||
df['datetime'] = pd.to_datetime(df['datetime'], utc=True)
|
||||
df = df.groupby("datetime").mean().reset_index()
|
||||
df["datetime"] = pd.to_datetime(df["datetime"], utc=True)
|
||||
df.sort_values(by="datetime", inplace=True)
|
||||
return df
|
||||
|
||||
|
||||
def set_batch_size(self, batch_size: int):
|
||||
self.batch_size = batch_size
|
||||
|
||||
def get_dataloader(self, dataset, shuffle: bool = True):
|
||||
batch_size = len(dataset) if self.batch_size is None else self.batch_size
|
||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4)
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4
|
||||
)
|
||||
|
||||
def get_train_dataloader(self, transform: bool = True, predict_sequence_length: int = 96):
|
||||
def get_train_dataloader(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
):
|
||||
train_df = self.all_features.copy()
|
||||
|
||||
if self.train_range[0] != -np.inf:
|
||||
train_df = train_df[(train_df['datetime'] >= self.train_range[0])]
|
||||
|
||||
if self.train_range[1] != np.inf:
|
||||
train_df = train_df[(train_df['datetime'] <= self.train_range[1])]
|
||||
train_df = train_df[(train_df["datetime"] >= self.train_range[0])]
|
||||
|
||||
if self.train_range[1] != np.inf:
|
||||
train_df = train_df[(train_df["datetime"] <= self.train_range[1])]
|
||||
|
||||
if transform:
|
||||
train_df['nrv'] = self.nrv_scaler.fit_transform(train_df['nrv'].values.reshape(-1, 1)).reshape(-1)
|
||||
train_df['load_forecast'] = self.load_forecast_scaler.fit_transform(train_df['load_forecast'].values.reshape(-1, 1)).reshape(-1)
|
||||
train_df['total_load'] = self.load_forecast_scaler.transform(train_df['total_load'].values.reshape(-1, 1)).reshape(-1)
|
||||
|
||||
train_dataset = NrvDataset(train_df, data_config=self.data_config, full_day_skip=self.full_day_skip, predict_sequence_length=predict_sequence_length)
|
||||
train_df["nrv"] = self.nrv_scaler.fit_transform(
|
||||
train_df["nrv"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
train_df["load_forecast"] = self.load_forecast_scaler.fit_transform(
|
||||
train_df["load_forecast"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
train_df["total_load"] = self.load_forecast_scaler.transform(
|
||||
train_df["total_load"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
|
||||
train_dataset = NrvDataset(
|
||||
train_df,
|
||||
data_config=self.data_config,
|
||||
full_day_skip=self.full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
)
|
||||
return self.get_dataloader(train_dataset)
|
||||
|
||||
def get_test_dataloader(self, transform: bool = True, predict_sequence_length: int = 96):
|
||||
|
||||
def get_test_dataloader(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
):
|
||||
test_df = self.all_features.copy()
|
||||
|
||||
if self.test_range[0] != -np.inf:
|
||||
test_df = test_df[(test_df['datetime'] >= self.test_range[0])]
|
||||
|
||||
if self.test_range[1] != np.inf:
|
||||
test_df = test_df[(test_df['datetime'] <= self.test_range[1])]
|
||||
test_df = test_df[(test_df["datetime"] >= self.test_range[0])]
|
||||
|
||||
if self.test_range[1] != np.inf:
|
||||
test_df = test_df[(test_df["datetime"] <= self.test_range[1])]
|
||||
|
||||
if transform:
|
||||
test_df['nrv'] = self.nrv_scaler.transform(test_df['nrv'].values.reshape(-1, 1)).reshape(-1)
|
||||
test_df['load_forecast'] = self.load_forecast_scaler.transform(test_df['load_forecast'].values.reshape(-1, 1)).reshape(-1)
|
||||
test_df['total_load'] = self.load_forecast_scaler.transform(test_df['total_load'].values.reshape(-1, 1)).reshape(-1)
|
||||
test_df["nrv"] = self.nrv_scaler.transform(
|
||||
test_df["nrv"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
test_df["load_forecast"] = self.load_forecast_scaler.transform(
|
||||
test_df["load_forecast"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
test_df["total_load"] = self.load_forecast_scaler.transform(
|
||||
test_df["total_load"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
|
||||
test_dataset = NrvDataset(test_df, data_config=self.data_config, full_day_skip=self.full_day_skip, predict_sequence_length=predict_sequence_length)
|
||||
test_dataset = NrvDataset(
|
||||
test_df,
|
||||
data_config=self.data_config,
|
||||
full_day_skip=self.full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
)
|
||||
return self.get_dataloader(test_dataset, shuffle=False)
|
||||
|
||||
|
||||
def get_dataloaders(self, transform: bool = True, predict_sequence_length: int = 96):
|
||||
return self.get_train_dataloader(transform=transform, predict_sequence_length=predict_sequence_length), self.get_test_dataloader(transform=transform, predict_sequence_length=predict_sequence_length)
|
||||
|
||||
def inverse_transform(self, tensor: torch.Tensor):
|
||||
return self.nrv_scaler.inverse_transform(tensor.cpu().numpy()).reshape(-1)
|
||||
|
||||
|
||||
def get_dataloaders(
|
||||
self, transform: bool = True, predict_sequence_length: int = 96
|
||||
):
|
||||
return self.get_train_dataloader(
|
||||
transform=transform, predict_sequence_length=predict_sequence_length
|
||||
), self.get_test_dataloader(
|
||||
transform=transform, predict_sequence_length=predict_sequence_length
|
||||
)
|
||||
|
||||
def inverse_transform(self, input_data):
|
||||
try:
|
||||
if isinstance(input_data, torch.Tensor):
|
||||
if input_data.is_cuda:
|
||||
input_data = input_data.cpu()
|
||||
input_np = input_data.detach().numpy() # Convert to numpy array
|
||||
elif isinstance(input_data, np.ndarray):
|
||||
input_np = input_data
|
||||
else:
|
||||
raise TypeError("Input must be a PyTorch tensor or a NumPy array")
|
||||
|
||||
# Store the original shape
|
||||
original_shape = input_np.shape
|
||||
input_2d = input_np.reshape(-1, original_shape[-1])
|
||||
transformed_2d = self.nrv_scaler.inverse_transform(input_2d)
|
||||
|
||||
if isinstance(input_data, torch.Tensor):
|
||||
return torch.from_numpy(transformed_2d).view(original_shape)
|
||||
else:
|
||||
return transformed_2d.reshape(original_shape)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error in inverse_transform: {e}") from e
|
||||
|
||||
def get_input_size(self):
|
||||
data_loader = self.get_train_dataloader()
|
||||
data_loader = self.get_train_dataloader(
|
||||
predict_sequence_length=self.output_size
|
||||
)
|
||||
input, _ = next(iter(data_loader))
|
||||
return input.shape[-1]
|
||||
|
||||
Reference in New Issue
Block a user