Worked further on thesis
This commit is contained in:
@@ -59,13 +59,16 @@ def sample_diffusion(
|
||||
|
||||
# evenly spaces 4 intermediate samples to append between 1 and noise_steps
|
||||
if intermediate_samples:
|
||||
spacing = (noise_steps - 1) // 4
|
||||
if i % spacing == 0:
|
||||
first_quarter_end = (noise_steps - 1) // 4
|
||||
spacing = (first_quarter_end - 1) // 4
|
||||
|
||||
# save 1, 1 + spacing, 1 + 2*spacing, 1 + 3*spacing
|
||||
if i % spacing == 1 and i <= first_quarter_end:
|
||||
intermediate_samples_list.append(x)
|
||||
|
||||
x = torch.clamp(x, -1.0, 1.0)
|
||||
if len(intermediate_samples_list) > 0:
|
||||
return x, intermediate_samples_list
|
||||
return x, intermediate_samples_list[-4:]
|
||||
|
||||
return x
|
||||
|
||||
@@ -81,7 +84,7 @@ class DiffusionTrainer:
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
self.noise_steps = 1000
|
||||
self.noise_steps = 300
|
||||
self.beta_start = 0.0001
|
||||
self.beta_end = 0.02
|
||||
self.ts_length = 96
|
||||
@@ -260,6 +263,9 @@ class DiffusionTrainer:
|
||||
self.model = torch.load("checkpoint.pt")
|
||||
self.model.to(self.device)
|
||||
|
||||
self.debug_plots(task, True, train_loader, train_sample_indices, -1)
|
||||
self.debug_plots(task, False, test_loader, test_sample_indices, -1)
|
||||
|
||||
_, generated_sampels = self.test(test_loader, -1, task)
|
||||
# self.policy_evaluator.plot_profits_table()
|
||||
if self.policy_evaluator:
|
||||
@@ -371,7 +377,6 @@ class DiffusionTrainer:
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def debug_plots(self, task, training: bool, data_loader, sample_indices, epoch):
|
||||
for actual_idx, idx in sample_indices.items():
|
||||
features, target, _ = data_loader.dataset[idx]
|
||||
@@ -381,69 +386,93 @@ class DiffusionTrainer:
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
samples, intermediates = (
|
||||
self.sample(self.model, 100, features, True)
|
||||
)
|
||||
samples, intermediates = self.sample(self.model, 100, features, True)
|
||||
samples = samples.cpu().numpy()
|
||||
samples = self.data_processor.inverse_transform(samples)
|
||||
target = self.data_processor.inverse_transform(target)
|
||||
|
||||
# list to tensor intermediate samples
|
||||
intermediates = torch.stack(intermediates)
|
||||
if epoch == -1:
|
||||
# list to tensor intermediate samples
|
||||
intermediates = torch.stack(intermediates)
|
||||
|
||||
intermediate_fig1 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(
|
||||
intermediates[0].cpu().numpy()
|
||||
),
|
||||
target,
|
||||
)
|
||||
|
||||
intermediate_fig2 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(
|
||||
intermediates[1].cpu().numpy()
|
||||
),
|
||||
target,
|
||||
)
|
||||
|
||||
intermediate_fig3 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(
|
||||
intermediates[2].cpu().numpy()
|
||||
),
|
||||
target,
|
||||
)
|
||||
|
||||
intermediate_fig4 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(
|
||||
intermediates[3].cpu().numpy()
|
||||
),
|
||||
target,
|
||||
)
|
||||
|
||||
# report the intermediate figs to clearml
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=(
|
||||
f"Training Intermediates {actual_idx}"
|
||||
if training
|
||||
else f"Testing Intermediates {actual_idx}"
|
||||
),
|
||||
series=f"Sample intermediate 1",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig1,
|
||||
report_image=True,
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=(
|
||||
f"Training Intermediates {actual_idx}"
|
||||
if training
|
||||
else f"Testing Intermediates {actual_idx}"
|
||||
),
|
||||
series=f"Sample intermediate 2",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig2,
|
||||
report_image=True,
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=(
|
||||
f"Training Intermediates {actual_idx}"
|
||||
if training
|
||||
else f"Testing Intermediates {actual_idx}"
|
||||
),
|
||||
series=f"Sample intermediate 3",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig3,
|
||||
report_image=True,
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=(
|
||||
f"Training Intermediates {actual_idx}"
|
||||
if training
|
||||
else f"Testing Intermediates {actual_idx}"
|
||||
),
|
||||
series=f"Sample intermediate 4",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig4,
|
||||
report_image=True,
|
||||
)
|
||||
|
||||
fig = self.plot_from_samples(samples, target)
|
||||
intermediate_fig1 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(intermediates[0].cpu().numpy()), target
|
||||
)
|
||||
|
||||
intermediate_fig2 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(intermediates[1].cpu().numpy()), target
|
||||
)
|
||||
|
||||
intermediate_fig3 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(intermediates[2].cpu().numpy()), target
|
||||
)
|
||||
|
||||
intermediate_fig4 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(intermediates[3].cpu().numpy()), target
|
||||
)
|
||||
|
||||
|
||||
# report the intermediate figs to clearml
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
|
||||
series=f"Sample intermediate 1",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig1,
|
||||
report_image=True
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
|
||||
series=f"Sample intermediate 2",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig2,
|
||||
report_image=True
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
|
||||
series=f"Sample intermediate 3",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig3,
|
||||
report_image=True
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
|
||||
series=f"Sample intermediate 4",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig4,
|
||||
report_image=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if training else "Testing",
|
||||
|
||||
@@ -2,7 +2,7 @@ from src.utils.clearml import ClearMLHelper
|
||||
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
task = clearml_helper.get_task(
|
||||
task_name="Diffusion Training: hidden_sizes=[256, 256, 256], lr=0.0001, time_dim=8"
|
||||
task_name="Diffusion Training: hidden_sizes=[1024, 1024, 1024, 1024] (300 steps), lr=0.0001, time_dim=8 + Load + Wind + PV + NP"
|
||||
)
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
@@ -19,19 +19,16 @@ from src.policies.PolicyEvaluator import PolicyEvaluator
|
||||
data_config = DataConfig()
|
||||
data_config.NRV_HISTORY = True
|
||||
|
||||
data_config.LOAD_HISTORY = False
|
||||
data_config.LOAD_FORECAST = False
|
||||
data_config.LOAD_HISTORY = True
|
||||
data_config.LOAD_FORECAST = True
|
||||
|
||||
data_config.PV_FORECAST = False
|
||||
data_config.PV_HISTORY = False
|
||||
data_config.PV_FORECAST = True
|
||||
data_config.PV_HISTORY = True
|
||||
|
||||
data_config.WIND_FORECAST = False
|
||||
data_config.WIND_HISTORY = False
|
||||
data_config.WIND_FORECAST = True
|
||||
data_config.WIND_HISTORY = True
|
||||
|
||||
data_config.QUARTER = False
|
||||
data_config.DAY_OF_WEEK = False
|
||||
|
||||
data_config.NOMINAL_NET_POSITION = False
|
||||
data_config.NOMINAL_NET_POSITION = True
|
||||
|
||||
data_config = task.connect(data_config, name="data_features")
|
||||
|
||||
@@ -45,7 +42,7 @@ print("Input dim: ", inputDim)
|
||||
model_parameters = {
|
||||
"epochs": 15000,
|
||||
"learning_rate": 0.0001,
|
||||
"hidden_sizes": [256, 256, 256],
|
||||
"hidden_sizes": [1024, 1024, 1024, 1024],
|
||||
"time_dim": 8,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user