475 lines
16 KiB
Python
475 lines
16 KiB
Python
from clearml import Task
|
|
import torch
|
|
import torch.nn as nn
|
|
from src.policies.PolicyEvaluator import PolicyEvaluator
|
|
from torchinfo import summary
|
|
from src.losses.crps_metric import crps_from_samples
|
|
from src.data.preprocessing import DataProcessor
|
|
|
|
from src.models.diffusion_model import DiffusionModel
|
|
import numpy as np
|
|
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
import matplotlib.patches as mpatches
|
|
|
|
|
|
def sample_diffusion(
    model: DiffusionModel,
    n: int,
    inputs: torch.Tensor,
    noise_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    ts_length=96,
):
    """Draw ``n`` samples per conditioning row with DDPM ancestral sampling.

    Args:
        model: noise-prediction network, called as ``model(x_t, t, inputs)``.
        n: number of samples to draw for every row of ``inputs``.
        inputs: conditioning batch whose first dimension is the batch dimension
            (rank >= 2, e.g. ``(batch, features)`` or ``(batch, seq, features)``).
        noise_steps: length of the linear beta schedule.
        beta_start: first beta of the schedule.
        beta_end: last beta of the schedule.
        ts_length: length of every generated series.

    Returns:
        Tensor of shape ``(n * batch, ts_length)`` clamped to ``[-1, 1]``. The
        batch is tiled with ``Tensor.repeat``, so row ``s * batch + b`` holds
        sample ``s`` for input ``b``.

    Raises:
        ValueError: if ``inputs`` has no batch dimension (rank < 2).

    Note:
        ``model`` is left in eval mode; callers that continue training must call
        ``model.train()`` themselves.
    """
    if inputs.dim() < 2:
        raise ValueError(
            "inputs must have a leading batch dimension (rank >= 2), "
            f"got shape {tuple(inputs.shape)}"
        )

    device = next(model.parameters()).device
    beta = torch.linspace(beta_start, beta_end, noise_steps).to(device)
    alpha = 1.0 - beta
    alpha_hat = torch.cumprod(alpha, dim=0)

    # Tile the whole batch n times along dim 0; the trailing dims are kept as-is,
    # so this works for any rank (previously only ranks 2 and 3 were tiled).
    inputs = inputs.repeat(n, *([1] * (inputs.dim() - 1)))
    num_rows = inputs.shape[0]

    model.eval()
    with torch.no_grad():
        x = torch.randn(num_rows, ts_length).to(device)
        for i in reversed(range(1, noise_steps)):
            t = torch.full((num_rows,), i, dtype=torch.long, device=device)
            predicted_noise = model(x, t, inputs)
            _alpha = alpha[t][:, None]
            _alpha_hat = alpha_hat[t][:, None]
            _beta = beta[t][:, None]

            # The final (i == 1) update adds no fresh noise.
            noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)

            x = (
                1
                / torch.sqrt(_alpha)
                * (x - ((1 - _alpha) / torch.sqrt(1 - _alpha_hat)) * predicted_noise)
                + torch.sqrt(_beta) * noise
            )

    return torch.clamp(x, -1.0, 1.0)
|
|
|
|
|
|
class DiffusionTrainer:
    """Train a conditional DDPM noise-prediction model and evaluate it with CRPS.

    The trainer owns a linear beta schedule that is used both to noise training
    targets (``noise_time_series``) and to generate samples (``sample``).
    Metrics, debug plots and checkpoints are reported to an optional ClearML
    ``Task``; every ClearML call is skipped when no task is given.
    """

    def __init__(
        self,
        model: nn.Module,
        data_processor: DataProcessor,
        device: torch.device,
        policy_evaluator: PolicyEvaluator = None,
        *,
        noise_steps: int = 30,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        ts_length: int = 96,
    ):
        """Store collaborators and precompute the noise schedule.

        Args:
            model: noise-prediction network called as ``model(x_t, t, cond)``.
            data_processor: provides dataloaders and ``inverse_transform``.
            device: device that training and sampling run on.
            policy_evaluator: optional; when given, penalty optimisation is run
                after every evaluation and at the end of training.
            noise_steps: number of diffusion steps in the schedule.
            beta_start: first beta of the linear schedule.
            beta_end: last beta of the linear schedule.
            ts_length: length of the predicted/generated series.
        """
        self.model = model
        self.device = device

        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.ts_length = ts_length

        self.data_processor = data_processor

        self.beta = torch.linspace(self.beta_start, self.beta_end, self.noise_steps).to(
            self.device
        )
        self.alpha = 1.0 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.best_score = None
        # CPU copy of the weights that achieved ``best_score`` (set by save_checkpoint).
        self._best_state = None
        self.policy_evaluator = policy_evaluator

        # Warm start for the penalty search performed in ``test``.
        self.prev_optimal_penalty = 0

    def noise_time_series(self, x: torch.Tensor, t: torch.Tensor):
        """Apply the closed-form forward diffusion to a batch of series.

        Args:
            x: clean series of shape ``(batch_size, time_steps)``.
            t: 1-D integer tensor with one timestep index per row of ``x``.

        Returns:
            ``(x_t, noise)``: the noised series
            ``sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * noise`` and the
            standard-normal ``noise`` that was added.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])[:, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n: int):
        """Return ``n`` random timestep indices drawn uniformly from ``[1, noise_steps)``."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model: DiffusionModel, n: int, inputs: torch.Tensor):
        """Draw ``n`` samples per row of ``inputs`` with this trainer's schedule.

        ``model`` is switched back to training mode afterwards. The output rows
        are ordered as produced by ``sample_diffusion`` (row ``s * batch + b`` is
        sample ``s`` for input ``b``).
        """
        x = sample_diffusion(
            model,
            n,
            inputs,
            self.noise_steps,
            self.beta_start,
            self.beta_end,
            self.ts_length,
        )
        model.train()
        return x

    @staticmethod
    def _group_samples_by_input(
        sample: torch.Tensor, batch_size: int, number_of_samples: int
    ) -> torch.Tensor:
        """Reshape sampler output into ``(batch_size, number_of_samples, length)``.

        The sampler tiles the conditioning batch with ``Tensor.repeat``, so row
        ``s * batch_size + b`` is sample ``s`` of input ``b``. A direct
        ``reshape(batch_size, number_of_samples, -1)`` would mix samples of
        different inputs whenever ``batch_size > 1``.
        """
        return (
            sample.reshape(number_of_samples, batch_size, -1).transpose(0, 1).contiguous()
        )

    def random_samples(self, train: bool = True, num_samples: int = 10):
        """Pick a fixed, reproducible set of dataset items for the debug plots.

        Args:
            train: choose from the train loader's dataset if True, else the test one.
            num_samples: number of distinct items to choose.

        Returns:
            dict mapping each chosen value of ``dataset.full_day_valid_indices``
            to its position in ``dataset.valid_indices``.
        """
        train_loader, test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=self.ts_length
        )
        loader = train_loader if train else test_loader

        # A private RandomState seeded with 42 draws exactly what
        # np.random.seed(42) + np.random.choice did, without resetting NumPy's
        # global RNG for the rest of the process.
        rng = np.random.RandomState(42)
        actual_indices = rng.choice(
            loader.dataset.full_day_valid_indices, num_samples, replace=False
        )
        valid_indices = loader.dataset.valid_indices
        indices = {i: valid_indices.index(i) for i in actual_indices}

        print(actual_indices)

        return indices

    def init_clearml_task(self, task):
        """Tag ``task``, attach a torchinfo model summary and connect the data processor."""
        task.add_tags(self.model.__class__.__name__)
        task.add_tags(self.__class__.__name__)

        batch = 1024
        input_data = torch.randn(batch, self.ts_length).to(self.device)
        # Valid diffusion step indices (same range as training); casting randn to
        # long produced negative / meaningless timesteps.
        time_steps = self.sample_timesteps(batch).to(self.device)

        if self.data_processor.lstm:
            input_dim = self.data_processor.get_input_size()
            other_input_data = torch.randn(
                batch, input_dim[1], self.model.other_inputs_dim
            ).to(self.device)
        else:
            other_input_data = torch.randn(batch, self.model.other_inputs_dim).to(
                self.device
            )

        task.set_configuration_object(
            "model",
            str(
                summary(
                    self.model, input_data=[input_data, time_steps, other_input_data]
                )
            ),
        )

        self.data_processor = task.connect(self.data_processor, name="data_processor")

    def train(self, epochs: int, learning_rate: float, task: Task = None):
        """Train with the noise-prediction MSE objective.

        Every 150th epoch (excluding epoch 0) the model is evaluated with
        ``test`` (which checkpoints on a new best CRPS) and debug plots are
        reported. Training stops once more than 15 consecutive evaluations fail
        to improve CRPS. Afterwards the best weights seen in this run are
        restored, the model is evaluated once more and, if a policy evaluator is
        configured, the penalty is optimised and reported.
        """
        eval_every = 150
        patience = 15

        self.best_score = None
        self._best_state = None
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        criterion = nn.MSELoss()
        self.model.to(self.device)

        early_stopping = 0
        best_crps = None

        if task:
            self.init_clearml_task(task)

        train_loader, test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=self.ts_length
        )

        train_sample_indices = self.random_samples(train=True, num_samples=5)
        test_sample_indices = self.random_samples(train=False, num_samples=5)

        for epoch in range(epochs):
            running_loss = 0.0
            num_batches = 0
            for batch in train_loader:
                base_pattern, time_series = batch[0], batch[1]
                time_series = time_series.to(self.device)
                base_pattern = base_pattern.to(self.device)

                t = self.sample_timesteps(time_series.shape[0]).to(self.device)
                x_t, noise = self.noise_time_series(time_series, t)
                predicted_noise = self.model(x_t, t, base_pattern)
                loss = criterion(predicted_noise, noise)

                running_loss += loss.item()
                num_batches += 1

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Mean of the per-batch losses. Previously the sum was divided by the
            # number of samples and then discarded, and only the last batch's
            # loss was logged.
            epoch_loss = running_loss / num_batches if num_batches else float("nan")

            is_eval_epoch = epoch % eval_every == 0 and epoch != 0
            if is_eval_epoch:
                crps, _ = self.test(test_loader, epoch, task)

                if best_crps is None or crps < best_crps:
                    best_crps = crps
                    early_stopping = 0
                else:
                    early_stopping += 1

                if early_stopping > patience:
                    break

            if task:
                task.get_logger().report_scalar(
                    title=criterion.__class__.__name__,
                    series="train",
                    iteration=epoch,
                    value=epoch_loss,
                )

            if is_eval_epoch:
                self.debug_plots(task, True, train_loader, train_sample_indices, epoch)
                self.debug_plots(task, False, test_loader, test_sample_indices, epoch)

        # Restore the best weights from this run. They are kept in memory by
        # save_checkpoint, so a stale "checkpoint.pt" from an earlier run is never
        # picked up and no full-model unpickling is needed. If no evaluation ran,
        # the current weights are kept.
        if self._best_state is not None:
            self.model.load_state_dict(self._best_state)
        self.model.to(self.device)

        _, generated_samples = self.test(test_loader, None, task)

        if self.policy_evaluator is not None:
            optimal_penalty, profit, charge_cycles = (
                self.policy_evaluator.optimize_penalty_for_target_charge_cycles(
                    idx_samples=generated_samples,
                    test_loader=test_loader,
                    initial_penalty=900,
                    target_charge_cycles=283,
                    learning_rate=1,
                    max_iterations=50,
                    tolerance=1,
                )
            )

            print(
                f"Optimal Penalty: {optimal_penalty}, Profit: {profit}, Charge Cycles: {charge_cycles}"
            )

            if task:
                logger = task.get_logger()
                logger.report_single_value(name="Optimal Penalty", value=optimal_penalty)
                logger.report_single_value(name="Optimal Profit", value=profit)
                logger.report_single_value(
                    name="Optimal Charge Cycles", value=charge_cycles
                )

        if task:
            task.close()

    def debug_plots(self, task, training: bool, data_loader, sample_indices, epoch):
        """Report fan charts and raw sample traces for selected dataset items.

        For every ``(actual_idx, idx)`` in ``sample_indices``:
        - 100 samples are drawn conditioned on ``data_loader.dataset[idx]``.
        - A figure with the sample mean, the 50/90/95/99% central intervals and
          the real series is reported to ClearML.
        - A second figure with the first 10 raw samples is reported as well.

        Does nothing when no task is given.
        """
        if not task:
            return

        logger = task.get_logger()
        sns.set_theme()
        time_steps = np.arange(0, self.ts_length)
        # (label, lower quantile, upper quantile, legend alpha), widest first.
        intervals = [
            ("99% Interval", 0.005, 0.995, 0.3),
            ("95% Interval", 0.025, 0.975, 0.4),
            ("90% Interval", 0.05, 0.95, 0.5),
            ("50% Interval", 0.25, 0.75, 0.6),
        ]

        for actual_idx, idx in sample_indices.items():
            features, target, _ = data_loader.dataset[idx]
            features = features.to(self.device).unsqueeze(0)

            self.model.eval()
            with torch.no_grad():
                samples = self.sample(self.model, 100, features).cpu().numpy()
            samples = self.data_processor.inverse_transform(samples)
            target = self.data_processor.inverse_transform(target)

            fig, ax = plt.subplots(figsize=(20, 10))
            ax.plot(
                time_steps,
                samples.mean(axis=0),
                label="Mean of NRV samples",
                linewidth=3,
            )

            legend_patches = []
            for label, lower_q, upper_q, legend_alpha in intervals:
                ax.fill_between(
                    time_steps,
                    np.quantile(samples, lower_q, axis=0),
                    np.quantile(samples, upper_q, axis=0),
                    color="b",
                    alpha=0.2,
                    label=label,
                )
                # The 0.2-alpha fills stack and get darker toward the centre; the
                # legend patches use increasing alpha to approximate that.
                legend_patches.append(
                    mpatches.Patch(color="b", alpha=legend_alpha, label=label)
                )

            ax.plot(target, label="Real NRV", linewidth=3)
            ax.legend(handles=[*legend_patches, ax.lines[0], ax.lines[1]])
            ax.set_ylim([-1500, 1500])

            logger.report_matplotlib_figure(
                title="Training" if training else "Testing",
                series=f"Sample {actual_idx}",
                iteration=epoch,
                figure=fig,
            )
            plt.close(fig)

            # A few raw NRV generations (first 10 samples), same y-range.
            fig, ax = plt.subplots(figsize=(20, 10))
            for i in range(min(10, len(samples))):
                ax.plot(samples[i], label=f"Sample {i}")

            ax.plot(target, label="Real NRV", linewidth=3)
            ax.legend()
            ax.set_ylim([-1500, 1500])

            logger.report_matplotlib_figure(
                title="Training Samples" if training else "Testing Samples",
                series=f"Sample {actual_idx} samples",
                iteration=epoch,
                figure=fig,
                report_interactive=False,
            )
            plt.close(fig)

    def test(
        self, data_loader: torch.utils.data.DataLoader, epoch: int, task: Task = None
    ):
        """Evaluate the model with CRPS and checkpoint on improvement.

        Args:
            data_loader: yields ``(inputs, targets, idx_batch)`` batches.
            epoch: iteration used for ClearML reporting and checkpoint upload.
            task: optional ClearML task; reporting is skipped without one.

        Returns:
            ``(mean_crps, generated_samples)``. ``generated_samples`` maps each
            dataset index (``idx.item()``) to ``(inverse_transform(inputs[i][:96]),
            inverse_transform(samples for that input))``.
        """
        number_of_samples = 100
        all_crps = []
        generated_samples = {}
        for inputs, targets, idx_batch in data_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            sample = self.sample(self.model, number_of_samples, inputs)

            # (batch_size * number_of_samples, ts_length) ->
            # (batch_size, number_of_samples, ts_length), un-tiling the sampler's
            # sample-major row order.
            samples_batched = self._group_samples_by_input(
                sample, inputs.shape[0], number_of_samples
            )

            for i, (idx, samples) in enumerate(zip(idx_batch, samples_batched)):
                generated_samples[idx.item()] = (
                    self.data_processor.inverse_transform(inputs[i][:96]),
                    self.data_processor.inverse_transform(samples),
                )

            crps = crps_from_samples(samples_batched, targets)
            crps_mean = crps.mean(axis=1)
            all_crps.extend(crps_mean.tolist())

        mean_crps = np.array(all_crps).mean()

        if self.best_score is None or mean_crps < self.best_score:
            self.save_checkpoint(mean_crps, task, epoch)

        if task:
            task.get_logger().report_scalar(
                title="CRPS", series="test", value=mean_crps, iteration=epoch
            )

        if self.policy_evaluator:
            _, penalty_test_loader = self.data_processor.get_dataloaders(
                predict_sequence_length=self.ts_length, full_day_skip=True
            )

            optimal_penalty, profit, charge_cycles = (
                self.policy_evaluator.optimize_penalty_for_target_charge_cycles(
                    idx_samples=generated_samples,
                    test_loader=penalty_test_loader,
                    initial_penalty=self.prev_optimal_penalty,
                    target_charge_cycles=283,
                    learning_rate=1,
                    max_iterations=50,
                    tolerance=1,
                )
            )

            self.prev_optimal_penalty = optimal_penalty

            if task:
                logger = task.get_logger()
                logger.report_scalar(
                    title="Optimal Penalty",
                    series="test",
                    value=optimal_penalty,
                    iteration=epoch,
                )
                logger.report_scalar(
                    title="Optimal Profit", series="test", value=profit, iteration=epoch
                )
                logger.report_scalar(
                    title="Optimal Charge Cycles",
                    series="test",
                    value=charge_cycles,
                    iteration=epoch,
                )

        return mean_crps, generated_samples

    def save_checkpoint(self, val_loss, task, iteration: int):
        """Record the current model as the new best.

        - Pickles the whole model to ``checkpoint.pt``.
        - Keeps a CPU copy of its ``state_dict`` for restoring at the end of
          ``train``.
        - Uploads the file to ``task`` when one is given.
        - Stores ``val_loss`` as ``best_score``.
        """
        torch.save(self.model, "checkpoint.pt")
        self._best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in self.model.state_dict().items()
        }
        if task:
            task.update_output_model(
                model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
            )
        self.best_score = val_loss
|