Fixed small summary with model architectures until now
This commit is contained in:
@@ -6,18 +6,20 @@ import plotly.graph_objects as go
|
||||
import numpy as np
|
||||
from plotly.subplots import make_subplots
|
||||
from clearml.config import running_remotely
|
||||
|
||||
from torchinfo import summary
|
||||
|
||||
class Trainer:
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
input_dim: tuple,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
criterion: torch.nn.Module,
|
||||
data_processor: DataProcessor,
|
||||
device: torch.device,
|
||||
debug: bool = True,
|
||||
):
|
||||
self.input_dim = input_dim
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.criterion = criterion
|
||||
@@ -70,6 +72,8 @@ class Trainer:
|
||||
task.add_tags(self.optimizer.__class__.__name__)
|
||||
task.add_tags(self.__class__.__name__)
|
||||
|
||||
task.set_configuration_object("model", str(summary(self.model, self.input_dim)))
|
||||
|
||||
self.optimizer.name = self.optimizer.__class__.__name__
|
||||
self.criterion.name = self.criterion.__class__.__name__
|
||||
|
||||
|
||||
Reference in New Issue
Block a user