Added time as input feature
This commit is contained in:
28
src/models/time_embedding_layer.py
Normal file
28
src/models/time_embedding_layer.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
|
||||
class TimeEmbedding(nn.Module):
|
||||
def __init__(self, time_features: int, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.time_features = time_features
|
||||
print(time_features)
|
||||
self.embedding = nn.Embedding(time_features, embedding_dim)
|
||||
|
||||
def forward(self, x):
|
||||
# Extract the last 'time_features' from the input
|
||||
time_feature = x[:, -1]
|
||||
# convert to int
|
||||
time_feature = time_feature.int()
|
||||
# Embed these time features
|
||||
# print max value of time_feature
|
||||
if time_feature.max() > self.time_features:
|
||||
# print the row from x that includes the max value in the last column
|
||||
print(x[time_feature == time_feature.max()])
|
||||
print("time feature max value is greater than time features")
|
||||
|
||||
embedded_time = self.embedding(time_feature)
|
||||
# Concatenate the embedded features with the original input (minus the last 'time feature')
|
||||
return torch.cat((x[:, :-1], embedded_time), dim=1)
|
||||
|
||||
def output_dim(self, input_dim):
|
||||
return input_dim + self.embedding.embedding_dim - 1
|
||||
Reference in New Issue
Block a user