Added CRPS + profit logging and updated plots for non-autoregressive models
This commit is contained in:
@@ -173,6 +173,9 @@ class DiffusionTrainer:
|
||||
criterion = nn.MSELoss()
|
||||
self.model.to(self.device)
|
||||
|
||||
early_stopping = 0
|
||||
best_crps = None
|
||||
|
||||
if task:
|
||||
self.init_clearml_task(task)
|
||||
|
||||
@@ -204,7 +207,16 @@ class DiffusionTrainer:
|
||||
running_loss /= len(train_loader.dataset)
|
||||
|
||||
if epoch % 40 == 0 and epoch != 0:
|
||||
self.test(test_loader, epoch, task)
|
||||
crps = self.test(test_loader, epoch, task)
|
||||
|
||||
if best_crps is None or crps < best_crps:
|
||||
best_crps = crps
|
||||
early_stopping = 0
|
||||
else:
|
||||
early_stopping += 1
|
||||
|
||||
if early_stopping > 5:
|
||||
break
|
||||
|
||||
if task:
|
||||
task.get_logger().report_scalar(
|
||||
@@ -222,6 +234,13 @@ class DiffusionTrainer:
|
||||
task, False, test_loader, test_sample_indices, epoch
|
||||
)
|
||||
|
||||
# load the best model
|
||||
self.model = torch.load("checkpoint.pt")
|
||||
self.model.to(self.device)
|
||||
|
||||
self.test(test_loader, None, task)
|
||||
self.policy_evaluator.plot_profits_table()
|
||||
|
||||
if task:
|
||||
task.close()
|
||||
|
||||
@@ -329,7 +348,6 @@ class DiffusionTrainer:
|
||||
generated_samples = {}
|
||||
for inputs, targets, idx_batch in data_loader:
|
||||
inputs, targets = inputs.to(self.device), targets.to(self.device)
|
||||
print(inputs.shape, targets.shape)
|
||||
|
||||
number_of_samples = 100
|
||||
sample = self.sample(self.model, number_of_samples, inputs)
|
||||
@@ -388,6 +406,8 @@ class DiffusionTrainer:
|
||||
iteration=epoch,
|
||||
)
|
||||
|
||||
return mean_crps
|
||||
|
||||
def save_checkpoint(self, val_loss, task, iteration: int):
|
||||
torch.save(self.model, "checkpoint.pt")
|
||||
task.update_output_model(
|
||||
|
||||
Reference in New Issue
Block a user