Added LSTM model
This commit is contained in:
45
src/models/lstm_model.py
Normal file
45
src/models/lstm_model.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
|
||||
class LSTMModel(torch.nn.Module):
    """Stacked LSTM followed by a linear read-out layer.

    The whole input sequence is summarised by the final layer's last
    hidden state, which is then projected to ``output_size``.
    """

    def __init__(self, inputSize, output_size, num_layers: int, hidden_size: int, dropout: float = 0.2):
        """
        Args:
            inputSize: input shape; only its last entry (feature count)
                sizes the LSTM. NOTE(review): assumed to be a shape-like
                sequence such as ``torch.Size`` — confirm with callers.
            output_size: width of the final linear projection.
            num_layers: number of stacked LSTM layers.
            hidden_size: hidden-state width of each LSTM layer.
            dropout: inter-layer dropout probability (PyTorch ignores it
                when ``num_layers == 1``).
        """
        super().__init__()

        # Keep the configuration around for introspection / checkpointing.
        self.inputSize = inputSize
        self.output_size = output_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dropout = dropout

        # batch_first=True -> inputs arrive as (batch, seq_len, features).
        self.lstm = torch.nn.LSTM(
            input_size=inputSize[-1],
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.linear = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Encode the sequence and project the top layer's final hidden state."""
        _, (hidden_state, _) = self.lstm(x)
        # hidden_state is (num_layers, batch, hidden); take the top layer.
        return self.linear(hidden_state[-1])
class GRUModel(torch.nn.Module):
    """Stacked GRU followed by a linear read-out layer.

    Mirrors ``LSTMModel`` but uses GRU cells: the top layer's last
    hidden state is projected to ``output_size``.
    """

    def __init__(self, inputSize, output_size, num_layers: int, hidden_size: int, dropout: float = 0.2):
        """
        Args:
            inputSize: input shape; only its last entry (feature count)
                sizes the GRU. NOTE(review): assumed to be a shape-like
                sequence such as ``torch.Size`` — confirm with callers.
            output_size: width of the final linear projection.
            num_layers: number of stacked GRU layers.
            hidden_size: hidden-state width of each GRU layer.
            dropout: inter-layer dropout probability (PyTorch ignores it
                when ``num_layers == 1``).
        """
        super().__init__()

        # Keep the configuration around for introspection / checkpointing.
        self.inputSize = inputSize
        self.output_size = output_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.dropout = dropout

        # batch_first=True -> inputs arrive as (batch, seq_len, features).
        self.gru = torch.nn.GRU(
            input_size=inputSize[-1],
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.linear = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Encode the sequence and project the top layer's final hidden state."""
        _, hidden_state = self.gru(x)
        # hidden_state is (num_layers, batch, hidden); take the top layer.
        return self.linear(hidden_state[-1])
@@ -10,19 +10,20 @@ class TimeEmbedding(nn.Module):
|
||||
|
||||
def forward(self, x):
    """Embed the integer time feature stored in the last input column.

    The last element along the final axis of ``x`` is treated as an
    integer time index, looked up in ``self.embedding``, and the
    resulting embedding vector is concatenated onto the remaining
    features in place of the raw time column.

    Args:
        x: input tensor whose last feature (along the final axis) is a
            time index. NOTE(review): assumed to lie in
            ``[0, self.time_features)`` — confirm against callers.

    Returns:
        Tensor matching ``x`` except the final axis grows by
        ``embedding_dim - 1`` (time column removed, embedding appended).
    """
    # Ellipsis indexing handles any input rank, not just 2-D batches.
    time_feature = x[..., -1]
    # Embedding lookup requires integer indices.
    time_feature = time_feature.int()

    # Debug aid: surface out-of-range time indices before the embedding
    # lookup fails. NOTE(review): valid indices are presumably
    # 0..time_features-1, so this check likely wants ``>=`` — confirm.
    if time_feature.max() > self.time_features:
        # Show the offending row(s) containing the maximal index.
        print(x[time_feature == time_feature.max()])
        print("time feature max value is greater than time features")

    embedded_time = self.embedding(time_feature)
    # Drop the raw time column and append its embedding along the last axis.
    return torch.cat((x[..., :-1], embedded_time), dim=-1)
||||
def output_dim(self, input_dim):
    """Compute the shape ``forward`` produces for an input of ``input_dim``.

    Only the final axis changes: the single time column is removed and
    replaced by its ``embedding_dim``-wide embedding vector.

    Args:
        input_dim: shape-like sequence (e.g. ``torch.Size``) describing
            the input, whose last entry is the feature count.

    Returns:
        torch.Size: ``input_dim`` with the last entry adjusted by
        ``self.embedding.embedding_dim - 1``.
    """
    # Work on a mutable copy of the shape.
    dims = list(input_dim)
    # Swap the one time column for its embedding vector.
    dims[-1] = dims[-1] - 1 + self.embedding.embedding_dim
    return torch.Size(dims)
||||
Reference in New Issue
Block a user