Saving diffusion model on better CRPS score
This commit is contained in:
@@ -28,6 +28,8 @@ class DiffusionTrainer:
|
||||
self.alpha = 1. - self.beta
|
||||
self.alpha_hat = torch.cumprod(self.alpha, dim=0)
|
||||
|
||||
self.best_score = None
|
||||
|
||||
def noise_time_series(self, x: torch.tensor, t: int):
|
||||
""" Add noise to time series
|
||||
Args:
|
||||
@@ -93,6 +95,7 @@ class DiffusionTrainer:
|
||||
self.data_processor = task.connect(self.data_processor, name="data_processor")
|
||||
|
||||
def train(self, epochs: int, learning_rate: float, task: Task = None):
|
||||
self.best_score = None
|
||||
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
|
||||
criterion = nn.MSELoss()
|
||||
self.model.to(self.device)
|
||||
@@ -220,6 +223,9 @@ class DiffusionTrainer:
|
||||
all_crps = np.array(all_crps)
|
||||
mean_crps = all_crps.mean()
|
||||
|
||||
if self.best_score is None or mean_crps < self.best_score:
|
||||
self.save_checkpoint(mean_crps, task, epoch)
|
||||
|
||||
if task:
|
||||
task.get_logger().report_scalar(
|
||||
title="CRPS",
|
||||
@@ -227,4 +233,11 @@ class DiffusionTrainer:
|
||||
value=mean_crps,
|
||||
iteration=epoch
|
||||
)
|
||||
|
||||
def save_checkpoint(self, val_loss, task, iteration: int):
|
||||
torch.save(self.model.state_dict(), "checkpoint.pt")
|
||||
task.update_output_model(
|
||||
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
|
||||
)
|
||||
self.best_score = val_loss
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ data_config.DAY_OF_WEEK = False
|
||||
|
||||
data_config.NOMINAL_NET_POSITION = True
|
||||
|
||||
data_config = Task.connect(data_config, name="data_features")
|
||||
data_config = task.connect(data_config, name="data_features")
|
||||
|
||||
data_processor = DataProcessor(data_config, path="", lstm=False)
|
||||
data_processor.set_batch_size(8192)
|
||||
|
||||
Reference in New Issue
Block a user