Using val set for diffusion trainer
This commit is contained in:
@@ -49,6 +49,10 @@ class DataProcessor:
|
||||
-np.inf,
|
||||
datetime(year=2022, month=11, day=30, tzinfo=pytz.UTC),
|
||||
)
|
||||
self.val_range = (
|
||||
datetime(year=2022, month=10, day=1, tzinfo=pytz.UTC),
|
||||
datetime(year=2022, month=11, day=30, tzinfo=pytz.UTC),
|
||||
)
|
||||
self.test_range = (datetime(year=2023, month=1, day=1, tzinfo=pytz.UTC), np.inf)
|
||||
|
||||
self.update_range_str()
|
||||
@@ -227,14 +231,23 @@ class DataProcessor:
|
||||
transform: bool = True,
|
||||
predict_sequence_length: int = 96,
|
||||
shuffle: bool = True,
|
||||
with_validation: bool = False,
|
||||
):
|
||||
train_df = self.all_features.copy()
|
||||
|
||||
train_range = self.train_range
|
||||
|
||||
if with_validation:
|
||||
train_range = (
|
||||
self.train_range[0],
|
||||
self.val_range[0] - pd.Timedelta(days=1),
|
||||
)
|
||||
|
||||
if self.train_range[0] != -np.inf:
|
||||
train_df = train_df[(train_df["datetime"] >= self.train_range[0])]
|
||||
train_df = train_df[(train_df["datetime"] >= train_range[0])]
|
||||
|
||||
if self.train_range[1] != np.inf:
|
||||
train_df = train_df[(train_df["datetime"] <= self.train_range[1])]
|
||||
train_df = train_df[(train_df["datetime"] <= train_range[1])]
|
||||
|
||||
if transform:
|
||||
train_df["nrv"] = self.nrv_scaler.fit_transform(
|
||||
@@ -276,6 +289,58 @@ class DataProcessor:
|
||||
)
|
||||
return self.get_dataloader(train_dataset, shuffle=shuffle)
|
||||
|
||||
def get_val_dataloader(
|
||||
self,
|
||||
transform: bool = True,
|
||||
predict_sequence_length: int = 96,
|
||||
full_day_skip: bool = False,
|
||||
):
|
||||
val_df = self.all_features.copy()
|
||||
|
||||
if self.test_range[0] != -np.inf:
|
||||
val_df = val_df[(val_df["datetime"] >= self.val_range[0])]
|
||||
|
||||
if self.test_range[1] != np.inf:
|
||||
val_df = val_df[(val_df["datetime"] <= self.val_range[1])]
|
||||
|
||||
if transform:
|
||||
val_df["nrv"] = self.nrv_scaler.transform(
|
||||
val_df["nrv"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
val_df["load_forecast"] = self.load_forecast_scaler.transform(
|
||||
val_df["load_forecast"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
val_df["total_load"] = self.load_forecast_scaler.transform(
|
||||
val_df["total_load"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
|
||||
val_df["pv_forecast"] = self.pv_forecast_scaler.transform(
|
||||
val_df["pv_forecast"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
|
||||
val_df["pv_history"] = self.pv_forecast_scaler.transform(
|
||||
val_df["pv_history"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
|
||||
val_df["wind_forecast"] = self.wind_forecast_scaler.transform(
|
||||
val_df["wind_forecast"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
val_df["wind_history"] = self.wind_forecast_scaler.transform(
|
||||
val_df["wind_history"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
val_df["nominal_net_position"] = self.nominal_net_position_scaler.transform(
|
||||
val_df["nominal_net_position"].values.reshape(-1, 1)
|
||||
).reshape(-1)
|
||||
|
||||
val_dataset = NrvDataset(
|
||||
val_df,
|
||||
data_config=self.data_config,
|
||||
full_day_skip=self.full_day_skip or full_day_skip,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
lstm=self.lstm,
|
||||
)
|
||||
return self.get_dataloader(val_dataset, shuffle=False)
|
||||
|
||||
def get_test_dataloader(
|
||||
self,
|
||||
transform: bool = True,
|
||||
@@ -335,14 +400,35 @@ class DataProcessor:
|
||||
transform: bool = True,
|
||||
predict_sequence_length: int = 96,
|
||||
full_day_skip: bool = False,
|
||||
validation: bool = False,
|
||||
):
|
||||
return self.get_train_dataloader(
|
||||
transform=transform, predict_sequence_length=predict_sequence_length
|
||||
), self.get_test_dataloader(
|
||||
transform=transform,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
full_day_skip=full_day_skip,
|
||||
)
|
||||
|
||||
if not validation:
|
||||
return self.get_train_dataloader(
|
||||
transform=transform, predict_sequence_length=predict_sequence_length
|
||||
), self.get_test_dataloader(
|
||||
transform=transform,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
full_day_skip=full_day_skip,
|
||||
)
|
||||
else:
|
||||
return (
|
||||
self.get_train_dataloader(
|
||||
transform=transform,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
with_validation=True,
|
||||
),
|
||||
self.get_val_dataloader(
|
||||
transform=transform,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
full_day_skip=full_day_skip,
|
||||
),
|
||||
self.get_test_dataloader(
|
||||
transform=transform,
|
||||
predict_sequence_length=predict_sequence_length,
|
||||
full_day_skip=full_day_skip,
|
||||
),
|
||||
)
|
||||
|
||||
def inverse_transform(self, input_data):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user