Updated training scripts
This commit is contained in:
@@ -85,6 +85,8 @@ class DiffusionTrainer:
|
||||
self.best_score = None
|
||||
self.policy_evaluator = policy_evaluator
|
||||
|
||||
self.prev_optimal_penalty = 0
|
||||
|
||||
def noise_time_series(self, x: torch.tensor, t: int):
|
||||
"""Add noise to time series
|
||||
Args:
|
||||
@@ -206,8 +208,8 @@ class DiffusionTrainer:
|
||||
|
||||
running_loss /= len(train_loader.dataset)
|
||||
|
||||
if epoch % 40 == 0 and epoch != 0:
|
||||
crps = self.test(test_loader, epoch, task)
|
||||
if epoch % 150 == 0 and epoch != 0:
|
||||
crps, _ = self.test(test_loader, epoch, task)
|
||||
|
||||
if best_crps is None or crps < best_crps:
|
||||
best_crps = crps
|
||||
@@ -215,7 +217,7 @@ class DiffusionTrainer:
|
||||
else:
|
||||
early_stopping += 1
|
||||
|
||||
if early_stopping > 5:
|
||||
if early_stopping > 15:
|
||||
break
|
||||
|
||||
if task:
|
||||
@@ -238,8 +240,32 @@ class DiffusionTrainer:
|
||||
self.model = torch.load("checkpoint.pt")
|
||||
self.model.to(self.device)
|
||||
|
||||
self.test(test_loader, None, task)
|
||||
self.policy_evaluator.plot_profits_table()
|
||||
_, generated_sampels = self.test(test_loader, None, task)
|
||||
# self.policy_evaluator.plot_profits_table()
|
||||
|
||||
optimal_penalty, profit, charge_cycles = (
|
||||
self.policy_evaluator.optimize_penalty_for_target_charge_cycles(
|
||||
idx_samples=generated_sampels,
|
||||
test_loader=test_loader,
|
||||
initial_penalty=900,
|
||||
target_charge_cycles=283,
|
||||
learning_rate=1,
|
||||
max_iterations=50,
|
||||
tolerance=1,
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
f"Optimal Penalty: {optimal_penalty}, Profit: {profit}, Charge Cycles: {charge_cycles}"
|
||||
)
|
||||
|
||||
task.get_logger().report_single_value(
|
||||
name="Optimal Penalty", value=optimal_penalty
|
||||
)
|
||||
task.get_logger().report_single_value(name="Optimal Profit", value=profit)
|
||||
task.get_logger().report_single_value(
|
||||
name="Optimal Charge Cycles", value=charge_cycles
|
||||
)
|
||||
|
||||
if task:
|
||||
task.close()
|
||||
@@ -332,6 +358,8 @@ class DiffusionTrainer:
|
||||
]
|
||||
)
|
||||
|
||||
ax.set_ylim([-1500, 1500])
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if training else "Testing",
|
||||
series=f"Sample {actual_idx}",
|
||||
@@ -341,6 +369,25 @@ class DiffusionTrainer:
|
||||
|
||||
plt.close()
|
||||
|
||||
# plot some samples for the nrv genearations (10 samples) (scale -1500 to 1500)
|
||||
fig, ax = plt.subplots(figsize=(20, 10))
|
||||
for i in range(10):
|
||||
ax.plot(samples[i], label=f"Sample {i}")
|
||||
|
||||
ax.plot(target, label="Real NRV", linewidth=3)
|
||||
ax.legend()
|
||||
ax.set_ylim([-1500, 1500])
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training Samples" if training else "Testing Samples",
|
||||
series=f"Sample {actual_idx} samples",
|
||||
iteration=epoch,
|
||||
figure=fig,
|
||||
report_interactive=False,
|
||||
)
|
||||
|
||||
plt.close()
|
||||
|
||||
def test(
|
||||
self, data_loader: torch.utils.data.DataLoader, epoch: int, task: Task = None
|
||||
):
|
||||
@@ -385,28 +432,39 @@ class DiffusionTrainer:
|
||||
predict_sequence_length=self.ts_length, full_day_skip=True
|
||||
)
|
||||
|
||||
self.policy_evaluator.evaluate_test_set(generated_samples, test_loader)
|
||||
|
||||
df = self.policy_evaluator.get_profits_as_scalars()
|
||||
|
||||
for idx, row in df.iterrows():
|
||||
task.get_logger().report_scalar(
|
||||
title="Profit",
|
||||
series=f"penalty_{row['Penalty']}",
|
||||
value=row["Total Profit"],
|
||||
iteration=epoch,
|
||||
optimal_penalty, profit, charge_cycles = (
|
||||
self.policy_evaluator.optimize_penalty_for_target_charge_cycles(
|
||||
idx_samples=generated_samples,
|
||||
test_loader=test_loader,
|
||||
initial_penalty=self.prev_optimal_penalty,
|
||||
target_charge_cycles=283,
|
||||
learning_rate=1,
|
||||
max_iterations=50,
|
||||
tolerance=1,
|
||||
)
|
||||
)
|
||||
|
||||
df = self.policy_evaluator.get_profits_till_400()
|
||||
for idx, row in df.iterrows():
|
||||
task.get_logger().report_scalar(
|
||||
title="Profit_till_400",
|
||||
series=f"penalty_{row['Penalty']}",
|
||||
value=row["Profit_till_400"],
|
||||
iteration=epoch,
|
||||
)
|
||||
self.prev_optimal_penalty = optimal_penalty
|
||||
|
||||
return mean_crps
|
||||
task.get_logger().report_scalar(
|
||||
title="Optimal Penalty",
|
||||
series="test",
|
||||
value=optimal_penalty,
|
||||
iteration=epoch,
|
||||
)
|
||||
|
||||
task.get_logger().report_scalar(
|
||||
title="Optimal Profit", series="test", value=profit, iteration=epoch
|
||||
)
|
||||
|
||||
task.get_logger().report_scalar(
|
||||
title="Optimal Charge Cycles",
|
||||
series="test",
|
||||
value=charge_cycles,
|
||||
iteration=epoch,
|
||||
)
|
||||
|
||||
return mean_crps, generated_samples
|
||||
|
||||
def save_checkpoint(self, val_loss, task, iteration: int):
|
||||
torch.save(self.model, "checkpoint.pt")
|
||||
|
||||
Reference in New Issue
Block a user