Added GRU results to thesis + intermediate samples of diffusion model
This commit is contained in:
@@ -22,6 +22,7 @@ def sample_diffusion(
|
||||
beta_start=1e-4,
|
||||
beta_end=0.02,
|
||||
ts_length=96,
|
||||
intermediate_samples: bool = False,
|
||||
):
|
||||
device = next(model.parameters()).device
|
||||
beta = torch.linspace(beta_start, beta_end, noise_steps).to(device)
|
||||
@@ -34,6 +35,7 @@ def sample_diffusion(
|
||||
inputs = inputs.repeat(n, 1, 1)
|
||||
|
||||
model.eval()
|
||||
intermediate_samples_list = []
|
||||
with torch.no_grad():
|
||||
x = torch.randn(inputs.shape[0], ts_length).to(device)
|
||||
for i in reversed(range(1, noise_steps)):
|
||||
@@ -54,7 +56,17 @@ def sample_diffusion(
|
||||
* (x - ((1 - _alpha) / (torch.sqrt(1 - _alpha_hat))) * predicted_noise)
|
||||
+ torch.sqrt(_beta) * noise
|
||||
)
|
||||
|
||||
# evenly spaces 4 intermediate samples to append between 1 and noise_steps
|
||||
if intermediate_samples:
|
||||
spacing = (noise_steps - 1) // 4
|
||||
if i % spacing == 0:
|
||||
intermediate_samples_list.append(x)
|
||||
|
||||
x = torch.clamp(x, -1.0, 1.0)
|
||||
if len(intermediate_samples_list) > 0:
|
||||
return x, intermediate_samples_list
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@@ -105,7 +117,13 @@ class DiffusionTrainer:
|
||||
"""
|
||||
return torch.randint(low=1, high=self.noise_steps, size=(n,))
|
||||
|
||||
def sample(self, model: DiffusionModel, n: int, inputs: torch.tensor):
|
||||
def sample(
|
||||
self,
|
||||
model: DiffusionModel,
|
||||
n: int,
|
||||
inputs: torch.tensor,
|
||||
intermediate_samples=False,
|
||||
):
|
||||
x = sample_diffusion(
|
||||
model,
|
||||
n,
|
||||
@@ -114,7 +132,9 @@ class DiffusionTrainer:
|
||||
self.beta_start,
|
||||
self.beta_end,
|
||||
self.ts_length,
|
||||
intermediate_samples,
|
||||
)
|
||||
|
||||
model.train()
|
||||
return x
|
||||
|
||||
@@ -271,6 +291,87 @@ class DiffusionTrainer:
|
||||
if task:
|
||||
task.close()
|
||||
|
||||
def plot_from_samples(self, samples, target):
    """Plot the sample mean, nested quantile bands and the real series.

    Args:
        samples: Array of generated series; the mean and all quantiles are
            taken over axis 0. Presumably shaped (n_samples, n_time_steps)
            -- confirm against callers.
        target: The real series, drawn on top of the bands against its own
            implicit index.

    Returns:
        The matplotlib figure holding the plot.
    """
    # (label, lower quantile, upper quantile, legend alpha), widest band first.
    # Every band is filled with the same alpha, so overlapping bands look
    # darker towards the centre; the legend alphas approximate that stacking.
    bands = (
        ("99% Interval", 0.005, 0.995, 0.3),
        ("95% Interval", 0.025, 0.975, 0.4),
        ("90% Interval", 0.05, 0.95, 0.5),
        ("50% Interval", 0.25, 0.75, 0.6),
    )

    sns.set_theme()

    # Derive the x axis from the data instead of assuming 96 time steps, so
    # series of any length can be plotted.
    time_steps = np.arange(np.shape(samples)[-1])

    fig, ax = plt.subplots(figsize=(20, 10))
    ax.plot(
        time_steps,
        samples.mean(axis=0),
        label="Mean of NRV samples",
        linewidth=3,
    )

    legend_handles = []
    for label, lower_q, upper_q, legend_alpha in bands:
        ax.fill_between(
            time_steps,
            np.quantile(samples, lower_q, axis=0),
            np.quantile(samples, upper_q, axis=0),
            color="b",
            alpha=0.2,
            label=label,
        )
        legend_handles.append(
            mpatches.Patch(color="b", alpha=legend_alpha, label=label)
        )

    ax.plot(target, label="Real NRV", linewidth=3)

    # ax.lines[0] is the sample mean and ax.lines[1] the real series; the
    # filled bands are not lines, so they do not shift these indices.
    ax.legend(handles=[*legend_handles, ax.lines[0], ax.lines[1]])

    ax.set_ylim([-1500, 1500])

    return fig
|
||||
|
||||
|
||||
def debug_plots(self, task, training: bool, data_loader, sample_indices, epoch):
|
||||
for actual_idx, idx in sample_indices.items():
|
||||
features, target, _ = data_loader.dataset[idx]
|
||||
@@ -280,86 +381,69 @@ class DiffusionTrainer:
|
||||
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
samples = self.sample(self.model, 100, features).cpu().numpy()
|
||||
samples, intermediates = (
|
||||
self.sample(self.model, 100, features, True)
|
||||
)
|
||||
samples = samples.cpu().numpy()
|
||||
samples = self.data_processor.inverse_transform(samples)
|
||||
target = self.data_processor.inverse_transform(target)
|
||||
|
||||
ci_99_upper = np.quantile(samples, 0.995, axis=0)
|
||||
ci_99_lower = np.quantile(samples, 0.005, axis=0)
|
||||
# list to tensor intermediate samples
|
||||
intermediates = torch.stack(intermediates)
|
||||
|
||||
ci_95_upper = np.quantile(samples, 0.975, axis=0)
|
||||
ci_95_lower = np.quantile(samples, 0.025, axis=0)
|
||||
|
||||
ci_90_upper = np.quantile(samples, 0.95, axis=0)
|
||||
ci_90_lower = np.quantile(samples, 0.05, axis=0)
|
||||
|
||||
ci_50_lower = np.quantile(samples, 0.25, axis=0)
|
||||
ci_50_upper = np.quantile(samples, 0.75, axis=0)
|
||||
|
||||
sns.set_theme()
|
||||
time_steps = np.arange(0, 96)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(20, 10))
|
||||
ax.plot(
|
||||
time_steps,
|
||||
samples.mean(axis=0),
|
||||
label="Mean of NRV samples",
|
||||
linewidth=3,
|
||||
)
|
||||
# ax.fill_between(time_steps, ci_lower, ci_upper, color='b', alpha=0.2, label='Full Interval')
|
||||
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_99_lower,
|
||||
ci_99_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="99% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_95_lower,
|
||||
ci_95_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="95% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_90_lower,
|
||||
ci_90_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="90% Interval",
|
||||
)
|
||||
ax.fill_between(
|
||||
time_steps,
|
||||
ci_50_lower,
|
||||
ci_50_upper,
|
||||
color="b",
|
||||
alpha=0.2,
|
||||
label="50% Interval",
|
||||
fig = self.plot_from_samples(samples, target)
|
||||
intermediate_fig1 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(intermediates[0].cpu().numpy()), target
|
||||
)
|
||||
|
||||
ax.plot(target, label="Real NRV", linewidth=3)
|
||||
# full_interval_patch = mpatches.Patch(color='b', alpha=0.2, label='Full Interval')
|
||||
ci_99_patch = mpatches.Patch(color="b", alpha=0.3, label="99% Interval")
|
||||
ci_95_patch = mpatches.Patch(color="b", alpha=0.4, label="95% Interval")
|
||||
ci_90_patch = mpatches.Patch(color="b", alpha=0.5, label="90% Interval")
|
||||
ci_50_patch = mpatches.Patch(color="b", alpha=0.6, label="50% Interval")
|
||||
|
||||
ax.legend(
|
||||
handles=[
|
||||
ci_99_patch,
|
||||
ci_95_patch,
|
||||
ci_90_patch,
|
||||
ci_50_patch,
|
||||
ax.lines[0],
|
||||
ax.lines[1],
|
||||
]
|
||||
intermediate_fig2 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(intermediates[1].cpu().numpy()), target
|
||||
)
|
||||
|
||||
ax.set_ylim([-1500, 1500])
|
||||
intermediate_fig3 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(intermediates[2].cpu().numpy()), target
|
||||
)
|
||||
|
||||
intermediate_fig4 = self.plot_from_samples(
|
||||
self.data_processor.inverse_transform(intermediates[3].cpu().numpy()), target
|
||||
)
|
||||
|
||||
|
||||
# report the intermediate figs to clearml
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
|
||||
series=f"Sample intermediate 1",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig1,
|
||||
report_image=True
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
|
||||
series=f"Sample intermediate 2",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig2,
|
||||
report_image=True
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
|
||||
series=f"Sample intermediate 3",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig3,
|
||||
report_image=True
|
||||
)
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title=f"Training Intermediates {actual_idx}" if training else f"Testing Intermediates {actual_idx}",
|
||||
series=f"Sample intermediate 4",
|
||||
iteration=epoch,
|
||||
figure=intermediate_fig4,
|
||||
report_image=True
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
task.get_logger().report_matplotlib_figure(
|
||||
title="Training" if training else "Testing",
|
||||
|
||||
@@ -2,7 +2,7 @@ from src.utils.clearml import ClearMLHelper
|
||||
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NrvForecast")
|
||||
task = clearml_helper.get_task(
|
||||
task_name="Diffusion Training: hidden_sizes=[1024, 1024, 1024], lr=0.0001, time_dim=8 + Load + PV + Wind + NP"
|
||||
task_name="Diffusion Training: hidden_sizes=[256, 256, 256], lr=0.0001, time_dim=8"
|
||||
)
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
@@ -19,19 +19,19 @@ from src.policies.PolicyEvaluator import PolicyEvaluator
|
||||
data_config = DataConfig()
|
||||
data_config.NRV_HISTORY = True
|
||||
|
||||
data_config.LOAD_HISTORY = True
|
||||
data_config.LOAD_FORECAST = True
|
||||
data_config.LOAD_HISTORY = False
|
||||
data_config.LOAD_FORECAST = False
|
||||
|
||||
data_config.PV_FORECAST = True
|
||||
data_config.PV_HISTORY = True
|
||||
data_config.PV_FORECAST = False
|
||||
data_config.PV_HISTORY = False
|
||||
|
||||
data_config.WIND_FORECAST = True
|
||||
data_config.WIND_HISTORY = True
|
||||
data_config.WIND_FORECAST = False
|
||||
data_config.WIND_HISTORY = False
|
||||
|
||||
data_config.QUARTER = True
|
||||
data_config.DAY_OF_WEEK = True
|
||||
data_config.QUARTER = False
|
||||
data_config.DAY_OF_WEEK = False
|
||||
|
||||
data_config.NOMINAL_NET_POSITION = True
|
||||
data_config.NOMINAL_NET_POSITION = False
|
||||
|
||||
data_config = task.connect(data_config, name="data_features")
|
||||
|
||||
@@ -45,7 +45,7 @@ print("Input dim: ", inputDim)
|
||||
model_parameters = {
|
||||
"epochs": 15000,
|
||||
"learning_rate": 0.0001,
|
||||
"hidden_sizes": [1024, 1024, 1024],
|
||||
"hidden_sizes": [256, 256, 256],
|
||||
"time_dim": 8,
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from src.utils.clearml import ClearMLHelper
|
||||
|
||||
#### ClearML ####
|
||||
clearml_helper = ClearMLHelper(project_name="Thesis/NAQR: GRU")
|
||||
task = clearml_helper.get_task(task_name="NAQR: GRU (2 - 256) + Load")
|
||||
task = clearml_helper.get_task(task_name="NAQR: GRU (8 - 512) + Load + PV + Wind + NP")
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
from src.policies.PolicyEvaluator import PolicyEvaluator
|
||||
@@ -30,13 +30,13 @@ data_config.NRV_HISTORY = True
|
||||
data_config.LOAD_HISTORY = True
|
||||
data_config.LOAD_FORECAST = True
|
||||
|
||||
data_config.WIND_FORECAST = False
|
||||
data_config.WIND_HISTORY = False
|
||||
data_config.WIND_FORECAST = True
|
||||
data_config.WIND_HISTORY = True
|
||||
|
||||
data_config.PV_FORECAST = False
|
||||
data_config.PV_HISTORY = False
|
||||
data_config.PV_FORECAST = True
|
||||
data_config.PV_HISTORY = True
|
||||
|
||||
data_config.NOMINAL_NET_POSITION = False
|
||||
data_config.NOMINAL_NET_POSITION = True
|
||||
|
||||
|
||||
data_config = task.connect(data_config, name="data_features")
|
||||
|
||||
Reference in New Issue
Block a user