diff --git a/requirements.txt b/requirements.txt index 42c82fc..39b6977 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ torch torchvision pandas mediapipe -joblib \ No newline at end of file +joblib +tensorboard \ No newline at end of file diff --git a/src/dataset.py b/src/dataset.py index 1263764..4f8d35f 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -1,9 +1,12 @@ -import torch import json -from keypoint_extractor import KeypointExtractor from collections import OrderedDict -from identifiers import LANDMARKS + import numpy as np +import torch + +from identifiers import LANDMARKS +from keypoint_extractor import KeypointExtractor + class WLASLDataset(torch.utils.data.Dataset): def __init__(self, json_file: str, missing: str, keypoint_extractor: KeypointExtractor, subset:str="train", keypoints_identifier: dict = None, transform=None): diff --git a/src/identifiers.py b/src/identifiers.py index c1f8465..9d9e88e 100644 --- a/src/identifiers.py +++ b/src/identifiers.py @@ -1,38 +1,38 @@ LANDMARKS = { # Pose Landmarks "nose": 0, - "left_eye_inner": 1, + # "left_eye_inner": 1, "left_eye": 2, - "left_eye_outer": 3, - "right_eye_inner": 4, + # "left_eye_outer": 3, + # "right_eye_inner": 4, "right_eye": 5, - "right_eye_outer": 6, + # "right_eye_outer": 6, "left_ear": 7, "right_ear": 8, "mouth_left": 9, - "mouth_right": 10, + # "mouth_right": 10, "left_shoulder": 11, "right_shoulder": 12, "left_elbow": 13, "right_elbow": 14, "left_wrist": 15, "right_wrist": 16, - "left_pinky": 17, - "right_pinky": 18, - "left_index": 19, - "right_index": 20, - "left_thumb": 21, - "right_thumb": 22, - "left_hip": 23, - "right_hip": 24, - "left_knee": 25, - "right_knee": 26, - "left_ankle": 27, - "right_ankle": 28, - "left_heel": 29, - "right_heel": 30, - "left_foot_index": 31, - "right_foot_index": 32, + # "left_pinky": 17, + # "right_pinky": 18, + # "left_index": 19, + # "right_index": 20, + # "left_thumb": 21, + # "right_thumb": 22, + # "left_hip": 23, + # "right_hip": 24, + # "left_knee": 25, + # "right_knee": 26, + # "left_ankle": 27, + # "right_ankle": 28, + # "left_heel": 29, + # "right_heel": 30, + # "left_foot_index": 31, + # "right_foot_index": 32, # Left Hand Landmarks "left_wrist2": 33, diff --git a/src/model.py b/src/model.py index c2aa1db..6ee9cce 100644 --- a/src/model.py +++ b/src/model.py @@ -1,11 +1,11 @@ ### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data" import copy -import torch - -import torch.nn as nn from typing import Optional +import torch +import torch.nn as nn + def _get_clones(mod, n): return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)]) @@ -51,7 +51,7 @@ class SPOTER(nn.Module): self.row_embed = nn.Parameter(torch.rand(50, hidden_dim)) self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0)) self.class_query = nn.Parameter(torch.rand(1, hidden_dim)) - self.transformer = nn.Transformer(hidden_dim, 10, 6, 6) + self.transformer = nn.Transformer(hidden_dim, 9, 6, 6) self.linear_class = nn.Linear(hidden_dim, num_classes) # Deactivate the initial attention decoder mechanism @@ -61,7 +61,6 @@ class SPOTER(nn.Module): def forward(self, inputs): h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float() - h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1) res = self.linear_class(h) diff --git a/src/train.py b/src/train.py index 44c7f34..245f6c5 100644 --- a/src/train.py +++ b/src/train.py @@ -1,22 +1,23 @@ -import os import argparse -import random import logging -import torch - -import numpy as np -import torch.nn as nn -import torch.optim as optim -import matplotlib.pyplot as plt -import matplotlib.ticker as ticker -from torchvision import transforms -from torch.utils.data import DataLoader +import os +import random from pathlib import Path -from keypoint_extractor import KeypointExtractor -from identifiers import LANDMARKS -from model import SPOTER +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import transforms + from dataset import WLASLDataset +from identifiers import LANDMARKS +from keypoint_extractor import KeypointExtractor +from model import SPOTER + def train(): random.seed(379) @@ -31,7 +32,7 @@ def train(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - spoter_model = SPOTER(num_classes=100, hidden_dim=2*75) + spoter_model = SPOTER(num_classes=100, hidden_dim=len(LANDMARKS) *2) spoter_model.train(True) spoter_model.to(device) @@ -47,7 +48,7 @@ def train(): train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train") train_loader = DataLoader(train_set, shuffle=True, generator=g) - + val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val") val_loader = DataLoader(val_set, shuffle=True, generator=g) @@ -56,53 +57,54 @@ def train(): train_acc, val_acc = 0, 0 - losses, train_accs, val_accs = [], [], [] lr_progress = [] top_train_acc, top_val_acc = 0, 0 checkpoint_index = 0 for epoch in range(100): + running_loss = 0.0 + pred_correct, pred_all = 0, 0 + # train for i, (inputs, labels) in enumerate(train_loader): inputs = inputs.squeeze(0).to(device) - labels = labels.to(device) + labels = labels.to(device, dtype=torch.long) optimizer.zero_grad() outputs = spoter_model(inputs).expand(1, -1, -1) loss = criterion(outputs[0], labels) loss.backward() optimizer.step() + running_loss += loss - _, predicted = torch.max(outputs.data, 1) - train_acc = (predicted == labels).sum().item() / labels.size(0) - - losses.append(loss.item()) - train_accs.append(train_acc) + if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]): + pred_correct += 1 + pred_all += 1 if i % 100 == 0: - print(f"Epoch: {epoch} | Batch: {i} | Loss: {loss.item()} | Train Acc: {train_acc}") + print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}") + + if scheduler: + scheduler.step(running_loss.item() / len(train_loader)) # validate with torch.no_grad(): for i, (inputs, labels) in enumerate(val_loader): - inputs = inputs.to(device) + inputs = inputs.squeeze(0).to(device) labels = labels.to(device) outputs = spoter_model(inputs) _, predicted = torch.max(outputs.data, 1) val_acc = (predicted == labels).sum().item() / labels.size(0) - val_accs.append(val_acc) - - scheduler.step(loss) # save checkpoint - if val_acc > top_val_acc: - top_val_acc = val_acc - top_train_acc = train_acc - checkpoint_index = epoch - torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth") + # if val_acc > top_val_acc: + # top_val_acc = val_acc + # top_train_acc = train_acc + # checkpoint_index = epoch + # torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth") print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}") lr_progress.append(optimizer.param_groups[0]['lr'])