Added CRPS + profit logging and updated plots for non-autoregressive models

This commit is contained in:
2024-02-28 17:12:51 +01:00
parent 420c9dc6ac
commit fe1e388ffb
6 changed files with 253 additions and 70 deletions

View File

@@ -173,6 +173,9 @@ class DiffusionTrainer:
criterion = nn.MSELoss()
self.model.to(self.device)
early_stopping = 0
best_crps = None
if task:
self.init_clearml_task(task)
@@ -204,7 +207,16 @@ class DiffusionTrainer:
running_loss /= len(train_loader.dataset)
if epoch % 40 == 0 and epoch != 0:
self.test(test_loader, epoch, task)
crps = self.test(test_loader, epoch, task)
if best_crps is None or crps < best_crps:
best_crps = crps
early_stopping = 0
else:
early_stopping += 1
if early_stopping > 5:
break
if task:
task.get_logger().report_scalar(
@@ -222,6 +234,13 @@ class DiffusionTrainer:
task, False, test_loader, test_sample_indices, epoch
)
# load the best model
self.model = torch.load("checkpoint.pt")
self.model.to(self.device)
self.test(test_loader, None, task)
self.policy_evaluator.plot_profits_table()
if task:
task.close()
@@ -329,7 +348,6 @@ class DiffusionTrainer:
generated_samples = {}
for inputs, targets, idx_batch in data_loader:
inputs, targets = inputs.to(self.device), targets.to(self.device)
print(inputs.shape, targets.shape)
number_of_samples = 100
sample = self.sample(self.model, number_of_samples, inputs)
@@ -388,6 +406,8 @@ class DiffusionTrainer:
iteration=epoch,
)
return mean_crps
def save_checkpoint(self, val_loss, task, iteration: int):
torch.save(self.model, "checkpoint.pt")
task.update_output_model(