diff --git a/src/trainers/diffusion_trainer.py b/src/trainers/diffusion_trainer.py
index f4dcfe7..ebe668b 100644
--- a/src/trainers/diffusion_trainer.py
+++ b/src/trainers/diffusion_trainer.py
@@ -28,6 +28,8 @@ class DiffusionTrainer:
         self.alpha = 1. - self.beta
         self.alpha_hat = torch.cumprod(self.alpha, dim=0)
 
+        self.best_score = None
+
     def noise_time_series(self, x: torch.tensor, t: int):
         """ Add noise to time series
         Args:
@@ -93,6 +95,7 @@ class DiffusionTrainer:
         self.data_processor = task.connect(self.data_processor, name="data_processor")
 
     def train(self, epochs: int, learning_rate: float, task: Task = None):
+        self.best_score = None
         optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
         criterion = nn.MSELoss()
         self.model.to(self.device)
@@ -220,6 +223,9 @@ class DiffusionTrainer:
         all_crps = np.array(all_crps)
         mean_crps = all_crps.mean()
 
+        if self.best_score is None or mean_crps < self.best_score:
+            self.save_checkpoint(mean_crps, task, epoch)
+
         if task:
             task.get_logger().report_scalar(
                 title="CRPS",
@@ -227,4 +233,14 @@ class DiffusionTrainer:
                 value=mean_crps,
                 iteration=epoch
             )
+
+    def save_checkpoint(self, val_loss, task, iteration: int):
+        """Save weights to checkpoint.pt, upload to ClearML if a task is given, record new best score."""
+        torch.save(self.model.state_dict(), "checkpoint.pt")
+        # Guard: train() allows task=None; skip the ClearML upload when untracked.
+        if task is not None:
+            task.update_output_model(
+                model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
+            )
+        self.best_score = val_loss
 
diff --git a/src/training_scripts/diffusion_training.py b/src/training_scripts/diffusion_training.py
index 83f4b85..fd384f3 100644
--- a/src/training_scripts/diffusion_training.py
+++ b/src/training_scripts/diffusion_training.py
@@ -35,7 +35,7 @@
 data_config.DAY_OF_WEEK = False
 data_config.NOMINAL_NET_POSITION = True
 
-data_config = Task.connect(data_config, name="data_features")
+data_config = task.connect(data_config, name="data_features")
 
 data_processor = DataProcessor(data_config, path="", lstm=False)
 data_processor.set_batch_size(8192)