Fixed some bugs in the training loop, still no good results

This commit is contained in:
2023-02-23 12:06:26 +00:00
parent 98f29f683e
commit 97ede38e3a
5 changed files with 68 additions and 63 deletions

View File

@@ -1,11 +1,11 @@
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
import copy
import torch
import torch.nn as nn
from typing import Optional
import torch
import torch.nn as nn
def _get_clones(mod, n):
return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)])
@@ -51,7 +51,7 @@ class SPOTER(nn.Module):
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
self.transformer = nn.Transformer(hidden_dim, 10, 6, 6)
self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
self.linear_class = nn.Linear(hidden_dim, num_classes)
# Deactivate the initial attention decoder mechanism
@@ -61,7 +61,6 @@ class SPOTER(nn.Module):
def forward(self, inputs):
h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
res = self.linear_class(h)