from clearml import Task
import torch
import torch.nn as nn
from src.policies.PolicyEvaluator import PolicyEvaluator
from torchinfo import summary
from src.losses.crps_metric import crps_from_samples
from src.data.preprocessing import DataProcessor
from src.models.diffusion_model import DiffusionModel
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.patches as mpatches


def sample_diffusion(
    model: DiffusionModel,
    n: int,
    inputs: torch.Tensor,
    noise_steps=1000,
    beta_start=1e-4,
    beta_end=0.02,
    ts_length=96,
):
    """Draw ``n`` time series per conditioning input via DDPM ancestral sampling.

    Args:
        model: trained noise-prediction network, called as ``model(x, t, inputs)``.
        n: number of samples to draw for each row of ``inputs``.
        inputs: conditioning tensor, 2-D ``(batch, features)`` or 3-D
            ``(batch, seq, features)``; repeated ``n`` times along dim 0 so
            every sample gets its own conditioning row.
        noise_steps: length of the diffusion schedule.
        beta_start: first value of the linear variance schedule.
        beta_end: last value of the linear variance schedule.
        ts_length: length of each generated series.

    Returns:
        Tensor of shape ``(batch * n, ts_length)`` clamped to ``[-1, 1]``.

    Note:
        Leaves ``model`` in eval mode; callers that keep training must
        restore train mode themselves (``DiffusionTrainer.sample`` does).
    """
    device = next(model.parameters()).device
    # Linear beta schedule and cumulative alpha products (standard DDPM).
    beta = torch.linspace(beta_start, beta_end, noise_steps).to(device)
    alpha = 1.0 - beta
    alpha_hat = torch.cumprod(alpha, dim=0)

    # Replicate the conditioning so each of the n samples has its own row.
    if inputs.dim() == 2:
        inputs = inputs.repeat(n, 1)
    elif inputs.dim() == 3:
        inputs = inputs.repeat(n, 1, 1)

    model.eval()
    with torch.no_grad():
        # Start from pure Gaussian noise and denoise step by step.
        x = torch.randn(inputs.shape[0], ts_length).to(device)
        for i in reversed(range(1, noise_steps)):
            t = (torch.ones(inputs.shape[0]) * i).long().to(device)
            predicted_noise = model(x, t, inputs)
            _alpha = alpha[t][:, None]
            _alpha_hat = alpha_hat[t][:, None]
            _beta = beta[t][:, None]
            # No extra noise is injected on the final reverse step.
            noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
            x = (
                1
                / torch.sqrt(_alpha)
                * (x - ((1 - _alpha) / (torch.sqrt(1 - _alpha_hat))) * predicted_noise)
                + torch.sqrt(_beta) * noise
            )
    # Training data is normalised to [-1, 1]; clip numerical overshoot.
    x = torch.clamp(x, -1.0, 1.0)
    return x


class DiffusionTrainer:
    """Trains a conditional diffusion model on time series and evaluates it
    with CRPS and an optional battery-policy evaluator, logging to ClearML."""

    def __init__(
        self,
        model: nn.Module,
        data_processor: DataProcessor,
        device: torch.device,
        policy_evaluator: PolicyEvaluator = None,
    ):
        """
        Args:
            model: noise-prediction network called as ``model(x_t, t, cond)``.
            data_processor: provides dataloaders and (inverse) scaling.
            device: device used for training and sampling.
            policy_evaluator: optional downstream evaluator; when ``None``
                the policy-optimisation steps are skipped.
        """
        self.model = model
        self.device = device
        # Short schedule used during training-time evaluation/sampling.
        self.noise_steps = 30
        self.beta_start = 0.0001
        self.beta_end = 0.02
        self.ts_length = 96
        self.data_processor = data_processor
        self.beta = torch.linspace(self.beta_start, self.beta_end, self.noise_steps).to(
            self.device
        )
        self.alpha = 1.0 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        self.best_score = None
        self.policy_evaluator = policy_evaluator
        # Warm start for the penalty search between evaluation rounds.
        self.prev_optimal_penalty = 0

    def noise_time_series(self, x: torch.Tensor, t: torch.Tensor):
        """Apply forward diffusion noise to a batch of time series.

        Args:
            x: clean series, shape ``(batch_size, time_steps)``.
            t: per-sample timestep indices, shape ``(batch_size,)``.

        Returns:
            Tuple ``(x_t, noise)``: the noised series and the Gaussian noise
            that was added (the regression target for the model).
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])[:, None]
        noise = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def sample_timesteps(self, n: int):
        """Uniformly sample ``n`` diffusion timesteps in ``[1, noise_steps)``.

        Args:
            n: number of timesteps to draw (one per batch element).
        """
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model: DiffusionModel, n: int, inputs: torch.Tensor):
        """Generate ``n`` samples per conditioning row using the trainer's
        schedule, then restore the model to train mode."""
        x = sample_diffusion(
            model,
            n,
            inputs,
            self.noise_steps,
            self.beta_start,
            self.beta_end,
            self.ts_length,
        )
        model.train()
        return x

    def random_samples(self, train: bool = True, num_samples: int = 10):
        """Pick a fixed (seeded) set of full-day samples for debug plotting.

        Args:
            train: choose from the train (True) or test (False) loader.
            num_samples: how many indices to pick.

        Returns:
            Dict mapping the dataset's "actual" full-day index to its
            position within ``valid_indices`` (the dataloader index).
        """
        train_loader, test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=96
        )
        loader = train_loader if train else test_loader
        # Fixed seed so the same samples are plotted every epoch/run.
        np.random.seed(42)
        actual_indices = np.random.choice(
            loader.dataset.full_day_valid_indices, num_samples, replace=False
        )
        indices = {i: loader.dataset.valid_indices.index(i) for i in actual_indices}
        print(actual_indices)
        return indices

    def init_clearml_task(self, task):
        """Tag the ClearML task, attach a model summary and connect the
        data processor configuration."""
        task.add_tags(self.model.__class__.__name__)
        task.add_tags(self.__class__.__name__)
        # Dummy tensors only used to drive torchinfo's summary.
        input_data = torch.randn(1024, 96).to(self.device)
        time_steps = torch.randn(1024).long().to(self.device)
        if self.data_processor.lstm:
            input_dim = self.data_processor.get_input_size()
            other_input_data = torch.randn(
                1024, input_dim[1], self.model.other_inputs_dim
            ).to(self.device)
        else:
            other_input_data = torch.randn(1024, self.model.other_inputs_dim).to(
                self.device
            )
        task.set_configuration_object(
            "model",
            str(
                summary(
                    self.model, input_data=[input_data, time_steps, other_input_data]
                )
            ),
        )
        self.data_processor = task.connect(self.data_processor, name="data_processor")

    def train(self, epochs: int, learning_rate: float, task: Task = None):
        """Train the diffusion model with periodic CRPS evaluation,
        early stopping and optional ClearML logging.

        Args:
            epochs: maximum number of training epochs.
            learning_rate: Adam learning rate.
            task: optional ClearML task; when ``None`` all logging is skipped.
        """
        self.best_score = None
        optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        criterion = nn.MSELoss()
        self.model.to(self.device)
        early_stopping = 0
        best_crps = None
        if task:
            self.init_clearml_task(task)
        train_loader, test_loader = self.data_processor.get_dataloaders(
            predict_sequence_length=self.ts_length
        )
        train_sample_indices = self.random_samples(train=True, num_samples=5)
        test_sample_indices = self.random_samples(train=False, num_samples=5)

        for epoch in range(epochs):
            running_loss = 0.0
            for batch in train_loader:
                # batch layout: (conditioning pattern, target series, ...).
                time_series, base_pattern = batch[1], batch[0]
                time_series = time_series.to(self.device)
                base_pattern = base_pattern.to(self.device)
                t = self.sample_timesteps(time_series.shape[0]).to(self.device)
                x_t, noise = self.noise_time_series(time_series, t)
                predicted_noise = self.model(x_t, t, base_pattern)
                loss = criterion(predicted_noise, noise)
                running_loss += loss.item()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            running_loss /= len(train_loader.dataset)

            # Evaluate CRPS every 150 epochs; stop after 15 rounds without
            # improvement (~2250 epochs of patience).
            if epoch % 150 == 0 and epoch != 0:
                crps, _ = self.test(test_loader, epoch, task)
                if best_crps is None or crps < best_crps:
                    best_crps = crps
                    early_stopping = 0
                else:
                    early_stopping += 1
                    if early_stopping > 15:
                        break

            if task:
                # BUGFIX: report the epoch-average loss (previously computed
                # but discarded in favour of the last batch's loss).
                task.get_logger().report_scalar(
                    title=criterion.__class__.__name__,
                    series="train",
                    iteration=epoch,
                    value=running_loss,
                )
                if epoch % 150 == 0 and epoch != 0:
                    self.debug_plots(
                        task, True, train_loader, train_sample_indices, epoch
                    )
                    self.debug_plots(
                        task, False, test_loader, test_sample_indices, epoch
                    )

        # Reload the best checkpoint saved during evaluation rounds.
        self.model = torch.load("checkpoint.pt")
        self.model.to(self.device)
        _, generated_samples = self.test(test_loader, None, task)

        # Final policy evaluation on the best model (skipped when no
        # evaluator was supplied).
        if self.policy_evaluator is not None:
            optimal_penalty, profit, charge_cycles = (
                self.policy_evaluator.optimize_penalty_for_target_charge_cycles(
                    idx_samples=generated_samples,
                    test_loader=test_loader,
                    initial_penalty=900,
                    target_charge_cycles=283,
                    learning_rate=1,
                    max_iterations=50,
                    tolerance=1,
                )
            )
            print(
                f"Optimal Penalty: {optimal_penalty}, Profit: {profit}, Charge Cycles: {charge_cycles}"
            )
            # BUGFIX: these reports previously ran even when task was None,
            # raising AttributeError.
            if task:
                task.get_logger().report_single_value(
                    name="Optimal Penalty", value=optimal_penalty
                )
                task.get_logger().report_single_value(
                    name="Optimal Profit", value=profit
                )
                task.get_logger().report_single_value(
                    name="Optimal Charge Cycles", value=charge_cycles
                )
        if task:
            task.close()

    def debug_plots(self, task, training: bool, data_loader, sample_indices, epoch):
        """Plot sample-mean + confidence bands and raw sample trajectories
        for fixed debug indices and report the figures to ClearML.

        Args:
            task: ClearML task used for figure reporting (must not be None).
            training: labels the plots as train or test samples.
            data_loader: loader whose dataset provides (features, target, _).
            sample_indices: mapping actual index -> dataset index, as
                produced by :meth:`random_samples`.
            epoch: reporting iteration.
        """
        for actual_idx, idx in sample_indices.items():
            features, target, _ = data_loader.dataset[idx]
            features = features.to(self.device).unsqueeze(0)
            self.model.eval()
            with torch.no_grad():
                samples = self.sample(self.model, 100, features).cpu().numpy()
            # Back to the original scale for plotting.
            samples = self.data_processor.inverse_transform(samples)
            target = self.data_processor.inverse_transform(target)

            # Central-interval bounds from the empirical sample quantiles.
            ci_99_upper = np.quantile(samples, 0.995, axis=0)
            ci_99_lower = np.quantile(samples, 0.005, axis=0)
            ci_95_upper = np.quantile(samples, 0.975, axis=0)
            ci_95_lower = np.quantile(samples, 0.025, axis=0)
            ci_90_upper = np.quantile(samples, 0.95, axis=0)
            ci_90_lower = np.quantile(samples, 0.05, axis=0)
            ci_50_lower = np.quantile(samples, 0.25, axis=0)
            ci_50_upper = np.quantile(samples, 0.75, axis=0)

            sns.set_theme()
            time_steps = np.arange(0, 96)
            fig, ax = plt.subplots(figsize=(20, 10))
            ax.plot(
                time_steps,
                samples.mean(axis=0),
                label="Mean of NRV samples",
                linewidth=3,
            )
            # Stacked translucent bands; overlapping alpha=0.2 fills darken
            # towards the median.
            for lower, upper, band_label in (
                (ci_99_lower, ci_99_upper, "99% Interval"),
                (ci_95_lower, ci_95_upper, "95% Interval"),
                (ci_90_lower, ci_90_upper, "90% Interval"),
                (ci_50_lower, ci_50_upper, "50% Interval"),
            ):
                ax.fill_between(
                    time_steps, lower, upper, color="b", alpha=0.2, label=band_label
                )
            ax.plot(target, label="Real NRV", linewidth=3)

            # Proxy patches so the legend shows the *effective* (stacked)
            # shade of each band rather than the raw alpha=0.2 fill.
            ci_99_patch = mpatches.Patch(color="b", alpha=0.3, label="99% Interval")
            ci_95_patch = mpatches.Patch(color="b", alpha=0.4, label="95% Interval")
            ci_90_patch = mpatches.Patch(color="b", alpha=0.5, label="90% Interval")
            ci_50_patch = mpatches.Patch(color="b", alpha=0.6, label="50% Interval")
            ax.legend(
                handles=[
                    ci_99_patch,
                    ci_95_patch,
                    ci_90_patch,
                    ci_50_patch,
                    ax.lines[0],
                    ax.lines[1],
                ]
            )
            ax.set_ylim([-1500, 1500])
            task.get_logger().report_matplotlib_figure(
                title="Training" if training else "Testing",
                series=f"Sample {actual_idx}",
                iteration=epoch,
                figure=fig,
            )
            plt.close()

            # Second figure: 10 individual generated trajectories vs. truth.
            fig, ax = plt.subplots(figsize=(20, 10))
            for i in range(10):
                ax.plot(samples[i], label=f"Sample {i}")
            ax.plot(target, label="Real NRV", linewidth=3)
            ax.legend()
            ax.set_ylim([-1500, 1500])
            task.get_logger().report_matplotlib_figure(
                title="Training Samples" if training else "Testing Samples",
                series=f"Sample {actual_idx} samples",
                iteration=epoch,
                figure=fig,
                report_interactive=False,
            )
            plt.close()

    def test(
        self, data_loader: torch.utils.data.DataLoader, epoch: int, task: Task = None
    ):
        """Evaluate the model: per-sample CRPS over ``data_loader``, optional
        checkpointing of the best score and optional policy evaluation.

        Args:
            data_loader: yields ``(inputs, targets, idx_batch)`` batches.
            epoch: reporting iteration (may be None for the final run).
            task: optional ClearML task; logging is skipped when None.

        Returns:
            Tuple ``(mean_crps, generated_samples)`` where
            ``generated_samples`` maps each dataset index to a pair of
            inverse-transformed ``(input series, generated samples)``.
        """
        all_crps = []
        generated_samples = {}
        number_of_samples = 100
        for inputs, targets, idx_batch in data_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            sample = self.sample(self.model, number_of_samples, inputs)
            # (batch*num_samples, T) -> (batch, num_samples, T)
            samples_batched = sample.reshape(
                inputs.shape[0], number_of_samples, self.ts_length
            )
            # BUGFIX: a stray pre-loop assignment referencing the undefined
            # names `idx`/`initial`/`samples` (NameError) was removed here.
            for i, (idx, samples) in enumerate(zip(idx_batch, samples_batched)):
                generated_samples[idx.item()] = (
                    self.data_processor.inverse_transform(inputs[i][: self.ts_length]),
                    self.data_processor.inverse_transform(samples),
                )
            crps = crps_from_samples(samples_batched, targets)
            # Average CRPS over time steps -> one score per series.
            crps_mean = crps.mean(axis=1)
            all_crps.extend(crps_mean.tolist())

        all_crps = np.array(all_crps)
        mean_crps = all_crps.mean()
        if self.best_score is None or mean_crps < self.best_score:
            self.save_checkpoint(mean_crps, task, epoch)
        if task:
            task.get_logger().report_scalar(
                title="CRPS", series="test", value=mean_crps, iteration=epoch
            )

        if self.policy_evaluator:
            _, test_loader = self.data_processor.get_dataloaders(
                predict_sequence_length=self.ts_length, full_day_skip=True
            )
            optimal_penalty, profit, charge_cycles = (
                self.policy_evaluator.optimize_penalty_for_target_charge_cycles(
                    idx_samples=generated_samples,
                    test_loader=test_loader,
                    initial_penalty=self.prev_optimal_penalty,
                    target_charge_cycles=283,
                    learning_rate=1,
                    max_iterations=50,
                    tolerance=1,
                )
            )
            # Warm-start the next penalty search from this optimum.
            self.prev_optimal_penalty = optimal_penalty
            # BUGFIX: guard logging — previously crashed when task was None.
            if task:
                task.get_logger().report_scalar(
                    title="Optimal Penalty",
                    series="test",
                    value=optimal_penalty,
                    iteration=epoch,
                )
                task.get_logger().report_scalar(
                    title="Optimal Profit",
                    series="test",
                    value=profit,
                    iteration=epoch,
                )
                task.get_logger().report_scalar(
                    title="Optimal Charge Cycles",
                    series="test",
                    value=charge_cycles,
                    iteration=epoch,
                )
        return mean_crps, generated_samples

    def save_checkpoint(self, val_loss, task, iteration: int):
        """Persist the whole model to ``checkpoint.pt``, register it with the
        ClearML task (when present) and remember the new best score."""
        torch.save(self.model, "checkpoint.pt")
        # BUGFIX: guard against task being None (train/test allow it).
        if task:
            task.update_output_model(
                model_path="checkpoint.pt", iteration=iteration, auto_delete_file=False
            )
        self.best_score = val_loss