Updated training script for GRU model
This commit is contained in:
@@ -10,7 +10,7 @@ from torch.nn import MSELoss, L1Loss
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from src.models.time_embedding_layer import TimeEmbedding
|
from src.models.time_embedding_layer import TimeEmbedding
|
||||||
from src.models.diffusion_model import SimpleDiffusionModel
|
from src.models.diffusion_model import GRUDiffusionModel, SimpleDiffusionModel
|
||||||
from src.trainers.diffusion_trainer import DiffusionTrainer
|
from src.trainers.diffusion_trainer import DiffusionTrainer
|
||||||
|
|
||||||
|
|
||||||
@@ -37,7 +37,7 @@ data_config.NOMINAL_NET_POSITION = True
|
|||||||
|
|
||||||
data_config = task.connect(data_config, name="data_features")
|
data_config = task.connect(data_config, name="data_features")
|
||||||
|
|
||||||
data_processor = DataProcessor(data_config, path="", lstm=False)
|
data_processor = DataProcessor(data_config, path="", lstm=True)
|
||||||
data_processor.set_batch_size(8192)
|
data_processor.set_batch_size(8192)
|
||||||
data_processor.set_full_day_skip(True)
|
data_processor.set_full_day_skip(True)
|
||||||
|
|
||||||
@@ -53,7 +53,8 @@ model_parameters = {
|
|||||||
model_parameters = task.connect(model_parameters, name="model_parameters")
|
model_parameters = task.connect(model_parameters, name="model_parameters")
|
||||||
|
|
||||||
#### Model ####
|
#### Model ####
|
||||||
model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"])
|
# model = SimpleDiffusionModel(96, model_parameters["hidden_sizes"], other_inputs_dim=inputDim[1], time_dim=model_parameters["time_dim"])
|
||||||
|
model = GRUDiffusionModel(96, [256, 256], other_inputs_dim=inputDim[2], time_dim=64, gru_hidden_size=128)
|
||||||
|
|
||||||
print("Starting training ...")
|
print("Starting training ...")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user