Added new training scripts
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from clearml import Task
|
||||
import torch
|
||||
from src.data.preprocessing import DataProcessor
|
||||
from src.utils.clearml import ClearMLHelper
|
||||
@@ -15,14 +16,12 @@ class Trainer:
|
||||
criterion: torch.nn.Module,
|
||||
data_processor: DataProcessor,
|
||||
device: torch.device,
|
||||
clearml_helper: ClearMLHelper = None,
|
||||
debug: bool = True,
|
||||
):
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
self.device = device
|
||||
self.clearml_helper = clearml_helper
|
||||
self.debug = debug
|
||||
|
||||
self.metrics_to_track = []
|
||||
@@ -48,12 +47,9 @@ class Trainer:
|
||||
else:
|
||||
self.metrics_to_track.append(loss)
|
||||
|
||||
def init_clearml_task(self):
|
||||
if not self.clearml_helper:
|
||||
return None
|
||||
|
||||
|
||||
task = self.clearml_helper.get_task(task_name="None")
|
||||
def init_clearml_task(self, task):
|
||||
if task is None:
|
||||
return
|
||||
|
||||
# check if running remotely
|
||||
|
||||
@@ -77,15 +73,14 @@ class Trainer:
|
||||
self.optimizer.name = self.optimizer.__class__.__name__
|
||||
self.criterion.name = self.criterion.__class__.__name__
|
||||
|
||||
task.connect(self.optimizer, name="optimizer")
|
||||
task.connect(self.criterion, name="criterion")
|
||||
task.connect(self.data_processor, name="data_processor")
|
||||
task.connect(self, name="trainer")
|
||||
self.optimizer = task.connect(self.optimizer, name="optimizer")
|
||||
self.criterion = task.connect(self.criterion, name="criterion")
|
||||
self.data_processor = task.connect(self.data_processor, name="data_processor")
|
||||
self = task.connect(self, name="trainer")
|
||||
|
||||
task.delete_parameter("trainer/quantiles")
|
||||
task.connect(self.data_processor.data_config, name="data_features")
|
||||
|
||||
return task
|
||||
|
||||
def random_samples(self, train: bool = True, num_samples: int = 10):
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size
|
||||
@@ -99,7 +94,7 @@ class Trainer:
|
||||
indices = np.random.randint(0, len(loader.dataset) - 1, size=num_samples)
|
||||
return indices
|
||||
|
||||
def train(self, epochs: int, remotely: bool = False):
|
||||
def train(self, epochs: int, remotely: bool = False, task: Task = None):
|
||||
try:
|
||||
train_loader, test_loader = self.data_processor.get_dataloaders(
|
||||
predict_sequence_length=self.model.output_size
|
||||
@@ -108,7 +103,7 @@ class Trainer:
|
||||
train_samples = self.random_samples(train=True)
|
||||
test_samples = self.random_samples(train=False)
|
||||
|
||||
task = self.init_clearml_task()
|
||||
self.init_clearml_task(task)
|
||||
|
||||
if remotely:
|
||||
task.execute_remotely(queue_name="default", exit_process=True)
|
||||
|
||||
Reference in New Issue
Block a user