Fixed small summary with model architectures until now

This commit is contained in:
Victor Mylle
2023-11-30 21:53:35 +00:00
parent eba10c8f83
commit 120b6aa5bd
23 changed files with 402 additions and 185 deletions

View File

@@ -6,18 +6,20 @@ import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
from clearml.config import running_remotely
from torchinfo import summary
class Trainer:
def __init__(
self,
model: torch.nn.Module,
input_dim: tuple,
optimizer: torch.optim.Optimizer,
criterion: torch.nn.Module,
data_processor: DataProcessor,
device: torch.device,
debug: bool = True,
):
self.input_dim = input_dim
self.model = model
self.optimizer = optimizer
self.criterion = criterion
@@ -70,6 +72,8 @@ class Trainer:
task.add_tags(self.optimizer.__class__.__name__)
task.add_tags(self.__class__.__name__)
task.set_configuration_object("model", str(summary(self.model, self.input_dim)))
self.optimizer.name = self.optimizer.__class__.__name__
self.criterion.name = self.criterion.__class__.__name__