Added time as input feature

This commit is contained in:
Victor Mylle
2023-11-26 18:43:03 +00:00
parent a2c9427d16
commit 2f40f41786
9 changed files with 168 additions and 138 deletions

View File

@@ -0,0 +1,28 @@
from torch import nn
import torch
class TimeEmbedding(nn.Module):
def __init__(self, time_features: int, embedding_dim: int):
super().__init__()
self.time_features = time_features
print(time_features)
self.embedding = nn.Embedding(time_features, embedding_dim)
def forward(self, x):
# Extract the last 'time_features' from the input
time_feature = x[:, -1]
# convert to int
time_feature = time_feature.int()
# Embed these time features
# print max value of time_feature
if time_feature.max() > self.time_features:
# print the row from x that includes the max value in the last column
print(x[time_feature == time_feature.max()])
print("time feature max value is greater than time features")
embedded_time = self.embedding(time_feature)
# Concatenate the embedded features with the original input (minus the last 'time feature')
return torch.cat((x[:, :-1], embedded_time), dim=1)
def output_dim(self, input_dim):
return input_dim + self.embedding.embedding_dim - 1