Using val set for diffusion trainer

This commit is contained in:
2024-05-17 13:24:02 +02:00
parent 615f9486bc
commit 11ae0e1949
13 changed files with 287 additions and 175 deletions

View File

@@ -49,6 +49,10 @@ class DataProcessor:
-np.inf,
datetime(year=2022, month=11, day=30, tzinfo=pytz.UTC),
)
self.val_range = (
datetime(year=2022, month=10, day=1, tzinfo=pytz.UTC),
datetime(year=2022, month=11, day=30, tzinfo=pytz.UTC),
)
self.test_range = (datetime(year=2023, month=1, day=1, tzinfo=pytz.UTC), np.inf)
self.update_range_str()
@@ -227,14 +231,23 @@ class DataProcessor:
transform: bool = True,
predict_sequence_length: int = 96,
shuffle: bool = True,
with_validation: bool = False,
):
train_df = self.all_features.copy()
train_range = self.train_range
if with_validation:
train_range = (
self.train_range[0],
self.val_range[0] - pd.Timedelta(days=1),
)
if self.train_range[0] != -np.inf:
train_df = train_df[(train_df["datetime"] >= self.train_range[0])]
train_df = train_df[(train_df["datetime"] >= train_range[0])]
if self.train_range[1] != np.inf:
train_df = train_df[(train_df["datetime"] <= self.train_range[1])]
train_df = train_df[(train_df["datetime"] <= train_range[1])]
if transform:
train_df["nrv"] = self.nrv_scaler.fit_transform(
@@ -276,6 +289,58 @@ class DataProcessor:
)
return self.get_dataloader(train_dataset, shuffle=shuffle)
def get_val_dataloader(
    self,
    transform: bool = True,
    predict_sequence_length: int = 96,
    full_day_skip: bool = False,
):
    """Build the unshuffled dataloader over the validation date range.

    Rows of ``self.all_features`` whose ``datetime`` lies inside
    ``self.val_range`` (both bounds inclusive) are selected, optionally
    scaled, and wrapped in an ``NrvDataset``.

    Args:
        transform: When True, scale the feature columns with the scalers
            held on this instance. Only ``transform`` is called here, so
            the scalers are never re-fitted on validation data and must
            already be fitted (presumably by the train dataloader, which
            calls ``fit_transform`` -- confirm against callers).
        predict_sequence_length: Forwarded unchanged to ``NrvDataset``.
        full_day_skip: OR-ed with ``self.full_day_skip`` before being
            forwarded to ``NrvDataset``.

    Returns:
        The result of ``self.get_dataloader`` for the validation dataset,
        requested with ``shuffle=False``.
    """
    val_df = self.all_features.copy()

    # The open-bound checks must inspect val_range itself. Guarding on
    # test_range (whose upper bound is set to np.inf in the constructor)
    # skipped the upper cut-off, so every row after the validation window,
    # test period included, leaked into the validation set.
    val_start, val_end = self.val_range
    if val_start != -np.inf:
        val_df = val_df[val_df["datetime"] >= val_start]
    if val_end != np.inf:
        val_df = val_df[val_df["datetime"] <= val_end]

    if transform:
        # History / realised columns deliberately share the scaler of the
        # matching forecast column, exactly as in the original statements.
        column_scalers = (
            ("nrv", self.nrv_scaler),
            ("load_forecast", self.load_forecast_scaler),
            ("total_load", self.load_forecast_scaler),
            ("pv_forecast", self.pv_forecast_scaler),
            ("pv_history", self.pv_forecast_scaler),
            ("wind_forecast", self.wind_forecast_scaler),
            ("wind_history", self.wind_forecast_scaler),
            ("nominal_net_position", self.nominal_net_position_scaler),
        )
        for column, scaler in column_scalers:
            # Scalers are fed a single-feature 2-D array; flatten the
            # result back to 1-D before storing it in the column.
            val_df[column] = scaler.transform(
                val_df[column].values.reshape(-1, 1)
            ).reshape(-1)

    val_dataset = NrvDataset(
        val_df,
        data_config=self.data_config,
        full_day_skip=self.full_day_skip or full_day_skip,
        predict_sequence_length=predict_sequence_length,
        lstm=self.lstm,
    )
    # Evaluation order must be deterministic, hence no shuffling.
    return self.get_dataloader(val_dataset, shuffle=False)
def get_test_dataloader(
self,
transform: bool = True,
@@ -335,14 +400,35 @@ class DataProcessor:
transform: bool = True,
predict_sequence_length: int = 96,
full_day_skip: bool = False,
validation: bool = False,
):
return self.get_train_dataloader(
transform=transform, predict_sequence_length=predict_sequence_length
), self.get_test_dataloader(
transform=transform,
predict_sequence_length=predict_sequence_length,
full_day_skip=full_day_skip,
)
if not validation:
return self.get_train_dataloader(
transform=transform, predict_sequence_length=predict_sequence_length
), self.get_test_dataloader(
transform=transform,
predict_sequence_length=predict_sequence_length,
full_day_skip=full_day_skip,
)
else:
return (
self.get_train_dataloader(
transform=transform,
predict_sequence_length=predict_sequence_length,
with_validation=True,
),
self.get_val_dataloader(
transform=transform,
predict_sequence_length=predict_sequence_length,
full_day_skip=full_day_skip,
),
self.get_test_dataloader(
transform=transform,
predict_sequence_length=predict_sequence_length,
full_day_skip=full_day_skip,
),
)
def inverse_transform(self, input_data):
try: