Quarter embedding using trigonometry + more thesis writing
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TimeEmbedding(nn.Module):
|
||||
def __init__(self, time_features: int, embedding_dim: int):
|
||||
@@ -17,9 +19,10 @@ class TimeEmbedding(nn.Module):
|
||||
# Embed these time features
|
||||
embedded_time = self.embedding(time_feature)
|
||||
# Concatenate the embedded features with the original input (minus the last 'time feature')
|
||||
return torch.cat((x[..., :-1], embedded_time), dim=-1) # Use -1 to specify the last dimension
|
||||
return torch.cat(
|
||||
(x[..., :-1], embedded_time), dim=-1
|
||||
) # Use -1 to specify the last dimension
|
||||
|
||||
|
||||
def output_dim(self, input_dim):
|
||||
if self.time_features == 0:
|
||||
return input_dim
|
||||
@@ -30,3 +33,32 @@ class TimeEmbedding(nn.Module):
|
||||
# Convert the list back to a torch.Size object
|
||||
output_dim = torch.Size(input_dim_list)
|
||||
return output_dim
|
||||
|
||||
|
||||
class TrigonometricTimeEmbedding(nn.Module):
|
||||
def __init__(self, time_features: int):
|
||||
super().__init__()
|
||||
self.time_features = time_features
|
||||
|
||||
def forward(self, x):
|
||||
if self.time_features == 0:
|
||||
return x
|
||||
time_feature = x[..., -1] # Use ellipsis to access the last dimension
|
||||
time_feature = time_feature.int()
|
||||
# Calculate the sine and cosine of the time feature
|
||||
sin_time = torch.sin(2 * np.pi * time_feature.float() / self.time_features)
|
||||
cos_time = torch.cos(2 * np.pi * time_feature.float() / self.time_features)
|
||||
# Stack the sine and cosine features
|
||||
time_embedding = torch.stack((sin_time, cos_time), dim=-1)
|
||||
# Concatenate the embedded features with the original input (minus the last 'time feature')
|
||||
return torch.cat(
|
||||
(x[..., :-1], time_embedding), dim=-1
|
||||
) # Use -1 to specify the last dimension
|
||||
|
||||
def output_dim(self, input_dim):
|
||||
if self.time_features == 0:
|
||||
return input_dim
|
||||
input_dim_list = list(input_dim)
|
||||
input_dim_list[-1] = input_dim_list[-1] - 1 + 2
|
||||
output_dim = torch.Size(input_dim_list)
|
||||
return output_dim
|
||||
|
||||
Reference in New Issue
Block a user