Saving diffusion model on better CRPS score

This commit is contained in:
Victor Mylle
2023-12-29 10:48:07 +00:00
parent 81231b9266
commit da3ab3d5b3
2 changed files with 14 additions and 1 deletions

View File

@@ -28,6 +28,8 @@ class DiffusionTrainer:
self.alpha = 1. - self.beta
self.alpha_hat = torch.cumprod(self.alpha, dim=0)
self.best_score = None
def noise_time_series(self, x: torch.tensor, t: int):
""" Add noise to time series
Args:
@@ -93,6 +95,7 @@ class DiffusionTrainer:
self.data_processor = task.connect(self.data_processor, name="data_processor")
def train(self, epochs: int, learning_rate: float, task: Task = None):
self.best_score = None
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
self.model.to(self.device)
@@ -220,6 +223,9 @@ class DiffusionTrainer:
all_crps = np.array(all_crps)
mean_crps = all_crps.mean()
if self.best_score is None or mean_crps < self.best_score:
self.save_checkpoint(mean_crps, task, epoch)
if task:
task.get_logger().report_scalar(
title="CRPS",
@@ -227,4 +233,11 @@ class DiffusionTrainer:
value=mean_crps,
iteration=epoch
)
def save_checkpoint(self, val_loss, task, iteration: int):
torch.save(self.model.state_dict(), "checkpoint.pt")
task.update_output_model(
model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
)
self.best_score = val_loss

View File

@@ -35,7 +35,7 @@ data_config.DAY_OF_WEEK = False
data_config.NOMINAL_NET_POSITION = True
data_config = Task.connect(data_config, name="data_features")
data_config = task.connect(data_config, name="data_features")
data_processor = DataProcessor(data_config, path="", lstm=False)
data_processor.set_batch_size(8192)