Add GRU results to thesis and intermediate samples of the diffusion model

This commit is contained in:
Victor Mylle
2024-05-06 23:28:42 +00:00
parent d7f4c1849b
commit d9b6f34e97
25 changed files with 353 additions and 161 deletions

View File

@@ -22,6 +22,7 @@ def sample_diffusion(
beta_start=1e-4,
beta_end=0.02,
ts_length=96,
intermediate_samples: bool = False,
):
device = next(model.parameters()).device
beta = torch.linspace(beta_start, beta_end, noise_steps).to(device)
@@ -34,6 +35,7 @@ def sample_diffusion(
inputs = inputs.repeat(n, 1, 1)
model.eval()
intermediate_samples_list = []
with torch.no_grad():
x = torch.randn(inputs.shape[0], ts_length).to(device)
for i in reversed(range(1, noise_steps)):
@@ -54,7 +56,17 @@ def sample_diffusion(
* (x - ((1 - _alpha) / (torch.sqrt(1 - _alpha_hat))) * predicted_noise)
+ torch.sqrt(_beta) * noise
)
# evenly spaces 4 intermediate samples to append between 1 and noise_steps
if intermediate_samples:
spacing = (noise_steps - 1) // 4
if i % spacing == 0:
intermediate_samples_list.append(x)
x = torch.clamp(x, -1.0, 1.0)
if len(intermediate_samples_list) > 0:
return x, intermediate_samples_list
return x
@@ -105,7 +117,13 @@ class DiffusionTrainer:
"""
return torch.randint(low=1, high=self.noise_steps, size=(n,))
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
def sample(
self,
model: DiffusionModel,
n: int,
inputs: torch.tensor,
intermediate_samples=False,
):
x = sample_diffusion(
model,
n,
@@ -114,7 +132,9 @@ class DiffusionTrainer:
self.beta_start,
self.beta_end,
self.ts_length,
intermediate_samples,
)
model.train()
return x
@@ -271,6 +291,87 @@ class DiffusionTrainer:
if task:
task.close()
def plot_from_samples(self, samples, target):
    """Plot the sample mean of *samples* with nested empirical confidence
    bands against the ground-truth series *target*.

    Parameters
    ----------
    samples : np.ndarray
        2-D array of shape (n_samples, n_timesteps); quantiles are taken
        over axis 0.  Previously the time axis was hard-coded to 96 steps;
        it is now derived from ``samples.shape[1]`` (backward compatible).
    target : array-like
        The real NRV series, plotted for comparison.

    Returns
    -------
    matplotlib.figure.Figure
        The assembled figure; the caller is responsible for reporting /
        closing it.
    """
    # (lower_q, upper_q, label, legend_alpha) for each nested central
    # interval.  All plotted bands use alpha=0.2 so overlapping bands
    # darken toward the median; the legend patches use increasing alpha
    # to mimic that visual stacking.
    intervals = [
        (0.005, 0.995, "99% Interval", 0.3),
        (0.025, 0.975, "95% Interval", 0.4),
        (0.050, 0.950, "90% Interval", 0.5),
        (0.250, 0.750, "50% Interval", 0.6),
    ]
    sns.set_theme()
    # Generalized: derive the time axis from the sample width instead of
    # hard-coding np.arange(0, 96).
    time_steps = np.arange(samples.shape[1])
    fig, ax = plt.subplots(figsize=(20, 10))
    ax.plot(
        time_steps,
        samples.mean(axis=0),
        label="Mean of NRV samples",
        linewidth=3,
    )
    legend_patches = []
    for lower_q, upper_q, label, patch_alpha in intervals:
        ax.fill_between(
            time_steps,
            np.quantile(samples, lower_q, axis=0),
            np.quantile(samples, upper_q, axis=0),
            color="b",
            alpha=0.2,
            label=label,
        )
        legend_patches.append(
            mpatches.Patch(color="b", alpha=patch_alpha, label=label)
        )
    ax.plot(target, label="Real NRV", linewidth=3)
    # Custom legend: one patch per band, then the two line artists
    # (sample mean first, real NRV second).
    ax.legend(handles=legend_patches + [ax.lines[0], ax.lines[1]])
    # Fixed display range matching the original plot.
    ax.set_ylim([-1500, 1500])
    return fig
def debug_plots(self, task, training: bool, data_loader, sample_indices, epoch):
for actual_idx, idx in sample_indices.items():
features, target, _ = data_loader.dataset[idx]
@@ -280,86 +381,69 @@ class DiffusionTrainer:
self.model.eval()
with torch.no_grad():
samples = self.sample(self.model, 100, features).cpu().numpy()
samples, intermediates = (
self.sample(self.model, 100, features, True)
)
samples = samples.cpu().numpy()
samples = self.data_processor.inverse_transform(samples)
target = self.data_processor.inverse_transform(target)
ci_99_upper = np.quantile(samples, 0.995, axis=0)
ci_99_lower = np.quantile(samples, 0.005, axis=0)
# list to tensor intermediate samples
intermediates = torch.stack(intermediates)
ci_95_upper = np.quantile(samples, 0.975, axis=0)
ci_95_lower = np.quantile(samples, 0.025, axis=0)
ci_90_upper = np.quantile(samples, 0.95, axis=0)
ci_90_lower = np.quantile(samples, 0.05, axis=0)
ci_50_lower = np.quantile(samples, 0.25, axis=0)
ci_50_upper = np.quantile(samples, 0.75, axis=0)
sns.set_theme()
time_steps = np.arange(0, 96)
fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(
time_steps,
samples.mean(axis=0),
label="Mean of NRV samples",
linewidth=3,
)
# ax.fill_between(time_steps, ci_lower, ci_upper, color='b', alpha=0.2, label='Full Interval')
ax.fill_between(
time_steps,
ci_99_lower,
ci_99_upper,
color="b",
alpha=0.2,
label="99% Interval",
)
ax.fill_between(
time_steps,
ci_95_lower,
ci_95_upper,
color="b",
alpha=0.2,
label="95% Interval",
)
ax.fill_between(
time_steps,
ci_90_lower,
ci_90_upper,
color="b",
alpha=0.2,
label="90% Interval",
)
ax.fill_between(
time_steps,
ci_50_lower,
ci_50_upper,
color="b",
alpha=0.2,
label="50% Interval",
fig = self.plot_from_samples(samples, target)
intermediate_fig1 = self.plot_from_samples(
self.data_processor.inverse_transform(intermediates[0].cpu().numpy()), target
)
ax.plot(target, label="Real NRV", linewidth=3)
# full_interval_patch = mpatches.Patch(color='b', alpha=0.2, label='Full Interval')
ci_99_patch = mpatches.Patch(color="b", alpha=0.3, label="99% Interval")
ci_95_patch = mpatches.Patch(color="b", alpha=0.4, label="95% Interval")
ci_90_patch = mpatches.Patch(color="b", alpha=0.5, label="90% Interval")
ci_50_patch = mpatches.Patch(color="b", alpha=0.6, label="50% Interval")
ax.legend(
handles=[
ci_99_patch,
ci_95_patch,
ci_90_patch,
ci_50_patch,
ax.lines[0],
ax.lines[1],
]
intermediate_fig2 = self.plot_from_samples(
self.data_processor.inverse_transform(intermediates[1].cpu().numpy()), target
)
ax.set_ylim([-1500, 1500])
intermediate_fig3 = self.plot_from_samples(
self.data_processor.inverse_transform(intermediates[2].cpu().numpy()), target
)
intermediate_fig4 = self.plot_from_samples(
self.data_processor.inverse_transform(intermediates[3].cpu().numpy()), target
)
# report the intermediate figs to clearml
task.get_logger().report_matplotlib_figure(
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
series=f"Sample intermediate 1",
iteration=epoch,
figure=intermediate_fig1,
report_image=True
)
task.get_logger().report_matplotlib_figure(
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
series=f"Sample intermediate 2",
iteration=epoch,
figure=intermediate_fig2,
report_image=True
)
task.get_logger().report_matplotlib_figure(
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
series=f"Sample intermediate 3",
iteration=epoch,
figure=intermediate_fig3,
report_image=True
)
task.get_logger().report_matplotlib_figure(
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
series=f"Sample intermediate 4",
iteration=epoch,
figure=intermediate_fig4,
report_image=True
)
task.get_logger().report_matplotlib_figure(
title="Training" if training else "Testing",

View File

@@ -2,7 +2,7 @@ from src.utils.clearml import ClearMLHelper
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
task = clearml_helper.get_task(
task_name="Diffusion Training: hidden_sizes=[1024, 1024, 1024], lr=0.0001, time_dim=8 + Load + PV + Wind + NP"
task_name="Diffusion Training: hidden_sizes=[256, 256, 256], lr=0.0001, time_dim=8"
)
task.execute_remotely(queue_name="default", exit_process=True)
@@ -19,19 +19,19 @@ from src.policies.PolicyEvaluator import PolicyEvaluator
data_config = DataConfig()
data_config.NRV_HISTORY = True
data_config.LOAD_HISTORY = True
data_config.LOAD_FORECAST = True
data_config.LOAD_HISTORY = False
data_config.LOAD_FORECAST = False
data_config.PV_FORECAST = True
data_config.PV_HISTORY = True
data_config.PV_FORECAST = False
data_config.PV_HISTORY = False
data_config.WIND_FORECAST = True
data_config.WIND_HISTORY = True
data_config.WIND_FORECAST = False
data_config.WIND_HISTORY = False
data_config.QUARTER = True
data_config.DAY_OF_WEEK = True
data_config.QUARTER = False
data_config.DAY_OF_WEEK = False
data_config.NOMINAL_NET_POSITION = True
data_config.NOMINAL_NET_POSITION = False
data_config = task.connect(data_config, name="data_features")
@@ -45,7 +45,7 @@ print("Input dim: ", inputDim)
model_parameters = {
"epochs": 15000,
"learning_rate": 0.0001,
"hidden_sizes": [1024, 1024, 1024],
"hidden_sizes": [256, 256, 256],
"time_dim": 8,
}

View File

@@ -2,7 +2,7 @@ from src.utils.clearml import ClearMLHelper
#### ClearML ####
clearml_helper = ClearMLHelper(project_name="Thesis/NAQR: GRU")
task = clearml_helper.get_task(task_name="NAQR: GRU (2 - 256) + Load")
task = clearml_helper.get_task(task_name="NAQR: GRU (8 - 512) + Load + PV + Wind + NP")
task.execute_remotely(queue_name="default", exit_process=True)
from src.policies.PolicyEvaluator import PolicyEvaluator
@@ -30,13 +30,13 @@ data_config.NRV_HISTORY = True
data_config.LOAD_HISTORY = True
data_config.LOAD_FORECAST = True
data_config.WIND_FORECAST = False
data_config.WIND_HISTORY = False
data_config.WIND_FORECAST = True
data_config.WIND_HISTORY = True
data_config.PV_FORECAST = False
data_config.PV_HISTORY = False
data_config.PV_FORECAST = True
data_config.PV_HISTORY = True
data_config.NOMINAL_NET_POSITION = False
data_config.NOMINAL_NET_POSITION = True
data_config = task.connect(data_config, name="data_features")