Added new training scripts

This commit is contained in:
Victor Mylle
2023-11-27 14:55:22 +00:00
parent 5e87165dbb
commit c1152ff96c
7 changed files with 37 additions and 36 deletions

View File

@@ -19,7 +19,6 @@ class AutoRegressiveTrainer(Trainer):
criterion: torch.nn.Module,
data_processor: DataProcessor,
device: torch.device,
clearml_helper: ClearMLHelper = None,
debug: bool = True,
):
super().__init__(
@@ -28,7 +27,6 @@ class AutoRegressiveTrainer(Trainer):
criterion=criterion,
data_processor=data_processor,
device=device,
clearml_helper=clearml_helper,
debug=debug,
)
self.model.output_size = 1

View File

@@ -10,12 +10,16 @@ import matplotlib.pyplot as plt
def sample_from_dist(quantiles, output_values):
# both to numpy
quantiles = quantiles.cpu().numpy()
# check if tensor:
if isinstance(quantiles, torch.Tensor):
quantiles = quantiles.cpu().numpy()
if isinstance(output_values, torch.Tensor):
output_values = output_values.cpu().numpy()
if isinstance(quantiles, list):
quantiles = np.array(quantiles)
reshaped_values = output_values.reshape(-1, len(quantiles))
uniform_random_numbers = np.random.uniform(0, 1, (reshaped_values.shape[0], 1000))
@@ -60,22 +64,18 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
data_processor: DataProcessor,
quantiles: list,
device: torch.device,
clearml_helper: ClearMLHelper = None,
debug: bool = True,
):
self.quantiles = quantiles
quantiles_tensor = torch.tensor(quantiles)
quantiles_tensor = quantiles_tensor.to(device)
criterion = PinballLoss(quantiles=quantiles_tensor)
criterion = PinballLoss(quantiles=quantiles)
super().__init__(
model=model,
optimizer=optimizer,
criterion=criterion,
data_processor=data_processor,
device=device,
clearml_helper=clearml_helper,
debug=debug,
)
@@ -252,7 +252,7 @@ class AutoRegressiveQuantileTrainer(AutoRegressiveTrainer):
def plot_quantile_percentages(
self, task, data_loader, train: bool = True, iteration: int = None
):
quantiles = self.quantiles.cpu().numpy()
quantiles = self.quantiles
total = 0
quantile_counter = {q: 0 for q in quantiles}

View File

@@ -1,3 +1,4 @@
from clearml import Task
import torch
from src.data.preprocessing import DataProcessor
from src.utils.clearml import ClearMLHelper
@@ -15,14 +16,12 @@ class Trainer:
criterion: torch.nn.Module,
data_processor: DataProcessor,
device: torch.device,
clearml_helper: ClearMLHelper = None,
debug: bool = True,
):
self.model = model
self.optimizer = optimizer
self.criterion = criterion
self.device = device
self.clearml_helper = clearml_helper
self.debug = debug
self.metrics_to_track = []
@@ -48,12 +47,9 @@ class Trainer:
else:
self.metrics_to_track.append(loss)
def init_clearml_task(self):
if not self.clearml_helper:
return None
task = self.clearml_helper.get_task(task_name="None")
def init_clearml_task(self, task):
if task is None:
return
# check if running remotely
@@ -77,15 +73,14 @@ class Trainer:
self.optimizer.name = self.optimizer.__class__.__name__
self.criterion.name = self.criterion.__class__.__name__
task.connect(self.optimizer, name="optimizer")
task.connect(self.criterion, name="criterion")
task.connect(self.data_processor, name="data_processor")
task.connect(self, name="trainer")
self.optimizer = task.connect(self.optimizer, name="optimizer")
self.criterion = task.connect(self.criterion, name="criterion")
self.data_processor = task.connect(self.data_processor, name="data_processor")
self = task.connect(self, name="trainer")
task.delete_parameter("trainer/quantiles")
task.connect(self.data_processor.data_config, name="data_features")
return task
def random_samples(self, train: bool = True, num_samples: int = 10):
train_loader, test_loader = self.data_processor.get_dataloaders(
predict_sequence_length=self.model.output_size
@@ -99,7 +94,7 @@ class Trainer:
indices = np.random.randint(0, len(loader.dataset) - 1, size=num_samples)
return indices
def train(self, epochs: int, remotely: bool = False):
def train(self, epochs: int, remotely: bool = False, task: Task = None):
try:
train_loader, test_loader = self.data_processor.get_dataloaders(
predict_sequence_length=self.model.output_size
@@ -108,7 +103,7 @@ class Trainer:
train_samples = self.random_samples(train=True)
test_samples = self.random_samples(train=False)
task = self.init_clearml_task()
self.init_clearml_task(task)
if remotely:
task.execute_remotely(queue_name="default", exit_process=True)