Using val set for diffusion trainer

This commit is contained in:
2024-05-17 13:24:02 +02:00
parent 615f9486bc
commit 11ae0e1949
13 changed files with 287 additions and 175 deletions

View File

@@ -1,8 +1,12 @@
from src.utils.clearml import ClearMLHelper
#### ClearML ####
clearml_helper = ClearMLHelper(project_name="Thesis/NAQR: GRU")
task = clearml_helper.get_task(task_name="NAQR: GRU (8 - 512) + Load + PV + Wind + NP")
clearml_helper = ClearMLHelper(
project_name="Thesis/NAQR: Non Linear (4 - 256) + Load + PV + Wind + NP"
)
task = clearml_helper.get_task(
task_name="NAQR: Non Linear (4 - 256) + Load + PV + Wind + NP"
)
task.execute_remotely(queue_name="default", exit_process=True)
from src.policies.PolicyEvaluator import PolicyEvaluator
@@ -119,32 +123,32 @@ trainer.plot_every(20)
trainer.train(task=task, epochs=epochs, remotely=True)
### Policy Evaluation ###
# idx_samples = trainer.test_set_samples
# _, test_loader = trainer.data_processor.get_dataloaders(
# predict_sequence_length=trainer.model.output_size, full_day_skip=False
# )
idx_samples = trainer.test_set_samples
_, test_loader = trainer.data_processor.get_dataloaders(
predict_sequence_length=trainer.model.output_size, full_day_skip=False
)
# policy_evaluator.evaluate_test_set(idx_samples, test_loader)
# policy_evaluator.plot_profits_table()
# policy_evaluator.plot_thresholds_per_day()
policy_evaluator.evaluate_test_set(idx_samples, test_loader)
policy_evaluator.plot_profits_table()
policy_evaluator.plot_thresholds_per_day()
# optimal_penalty, profit, charge_cycles = (
# policy_evaluator.optimize_penalty_for_target_charge_cycles(
# idx_samples=idx_samples,
# test_loader=test_loader,
# initial_penalty=1000,
# target_charge_cycles=283,
# learning_rate=15,
# max_iterations=150,
# tolerance=1,
# )
# )
optimal_penalty, profit, charge_cycles = (
policy_evaluator.optimize_penalty_for_target_charge_cycles(
idx_samples=idx_samples,
test_loader=test_loader,
initial_penalty=1000,
target_charge_cycles=283,
learning_rate=15,
max_iterations=150,
tolerance=1,
)
)
# print(
# f"Optimal Penalty: {optimal_penalty}, Profit: {profit}, Charge Cycles: {charge_cycles}"
# )
# task.get_logger().report_single_value(name="Optimal Penalty", value=optimal_penalty)
# task.get_logger().report_single_value(name="Optimal Profit", value=profit)
# task.get_logger().report_single_value(name="Optimal Charge Cycles", value=charge_cycles)
print(
f"Optimal Penalty: {optimal_penalty}, Profit: {profit}, Charge Cycles: {charge_cycles}"
)
task.get_logger().report_single_value(name="Optimal Penalty", value=optimal_penalty)
task.get_logger().report_single_value(name="Optimal Profit", value=profit)
task.get_logger().report_single_value(name="Optimal Charge Cycles", value=charge_cycles)
task.close()