Fixed some bugs in the training loop; still no good results
This commit is contained in:
@@ -3,3 +3,4 @@ torchvision
|
|||||||
pandas
|
pandas
|
||||||
mediapipe
|
mediapipe
|
||||||
joblib
|
joblib
|
||||||
|
tensorboard
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
import torch
|
|
||||||
import json
|
import json
|
||||||
from keypoint_extractor import KeypointExtractor
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from identifiers import LANDMARKS
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from identifiers import LANDMARKS
|
||||||
|
from keypoint_extractor import KeypointExtractor
|
||||||
|
|
||||||
|
|
||||||
class WLASLDataset(torch.utils.data.Dataset):
|
class WLASLDataset(torch.utils.data.Dataset):
|
||||||
def __init__(self, json_file: str, missing: str, keypoint_extractor: KeypointExtractor, subset:str="train", keypoints_identifier: dict = None, transform=None):
|
def __init__(self, json_file: str, missing: str, keypoint_extractor: KeypointExtractor, subset:str="train", keypoints_identifier: dict = None, transform=None):
|
||||||
|
|||||||
@@ -1,38 +1,38 @@
|
|||||||
LANDMARKS = {
|
LANDMARKS = {
|
||||||
# Pose Landmarks
|
# Pose Landmarks
|
||||||
"nose": 0,
|
"nose": 0,
|
||||||
"left_eye_inner": 1,
|
# "left_eye_inner": 1,
|
||||||
"left_eye": 2,
|
"left_eye": 2,
|
||||||
"left_eye_outer": 3,
|
# "left_eye_outer": 3,
|
||||||
"right_eye_inner": 4,
|
# "right_eye_inner": 4,
|
||||||
"right_eye": 5,
|
"right_eye": 5,
|
||||||
"right_eye_outer": 6,
|
# "right_eye_outer": 6,
|
||||||
"left_ear": 7,
|
"left_ear": 7,
|
||||||
"right_ear": 8,
|
"right_ear": 8,
|
||||||
"mouth_left": 9,
|
"mouth_left": 9,
|
||||||
"mouth_right": 10,
|
# "mouth_right": 10,
|
||||||
"left_shoulder": 11,
|
"left_shoulder": 11,
|
||||||
"right_shoulder": 12,
|
"right_shoulder": 12,
|
||||||
"left_elbow": 13,
|
"left_elbow": 13,
|
||||||
"right_elbow": 14,
|
"right_elbow": 14,
|
||||||
"left_wrist": 15,
|
"left_wrist": 15,
|
||||||
"right_wrist": 16,
|
"right_wrist": 16,
|
||||||
"left_pinky": 17,
|
# "left_pinky": 17,
|
||||||
"right_pinky": 18,
|
# "right_pinky": 18,
|
||||||
"left_index": 19,
|
# "left_index": 19,
|
||||||
"right_index": 20,
|
# "right_index": 20,
|
||||||
"left_thumb": 21,
|
# "left_thumb": 21,
|
||||||
"right_thumb": 22,
|
# "right_thumb": 22,
|
||||||
"left_hip": 23,
|
# "left_hip": 23,
|
||||||
"right_hip": 24,
|
# "right_hip": 24,
|
||||||
"left_knee": 25,
|
# "left_knee": 25,
|
||||||
"right_knee": 26,
|
# "right_knee": 26,
|
||||||
"left_ankle": 27,
|
# "left_ankle": 27,
|
||||||
"right_ankle": 28,
|
# "right_ankle": 28,
|
||||||
"left_heel": 29,
|
# "left_heel": 29,
|
||||||
"right_heel": 30,
|
# "right_heel": 30,
|
||||||
"left_foot_index": 31,
|
# "left_foot_index": 31,
|
||||||
"right_foot_index": 32,
|
# "right_foot_index": 32,
|
||||||
|
|
||||||
# Left Hand Landmarks
|
# Left Hand Landmarks
|
||||||
"left_wrist2": 33,
|
"left_wrist2": 33,
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
|
### SPOTER model implementation from the paper "SPOTER: Sign Pose-based Transformer for Sign Language Recognition from Sequence of Skeletal Data"
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import torch
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def _get_clones(mod, n):
|
def _get_clones(mod, n):
|
||||||
return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)])
|
return nn.ModuleList([copy.deepcopy(mod) for _ in range(n)])
|
||||||
@@ -51,7 +51,7 @@ class SPOTER(nn.Module):
|
|||||||
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
|
self.row_embed = nn.Parameter(torch.rand(50, hidden_dim))
|
||||||
self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
|
self.pos = nn.Parameter(torch.cat([self.row_embed[0].unsqueeze(0).repeat(1, 1, 1)], dim=-1).flatten(0, 1).unsqueeze(0))
|
||||||
self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
|
self.class_query = nn.Parameter(torch.rand(1, hidden_dim))
|
||||||
self.transformer = nn.Transformer(hidden_dim, 10, 6, 6)
|
self.transformer = nn.Transformer(hidden_dim, 9, 6, 6)
|
||||||
self.linear_class = nn.Linear(hidden_dim, num_classes)
|
self.linear_class = nn.Linear(hidden_dim, num_classes)
|
||||||
|
|
||||||
# Deactivate the initial attention decoder mechanism
|
# Deactivate the initial attention decoder mechanism
|
||||||
@@ -61,7 +61,6 @@ class SPOTER(nn.Module):
|
|||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
|
h = torch.unsqueeze(inputs.flatten(start_dim=1), 1).float()
|
||||||
|
|
||||||
h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
|
h = self.transformer(self.pos + h, self.class_query.unsqueeze(0)).transpose(0, 1)
|
||||||
res = self.linear_class(h)
|
res = self.linear_class(h)
|
||||||
|
|
||||||
|
|||||||
66
src/train.py
66
src/train.py
@@ -1,22 +1,23 @@
|
|||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import os
|
||||||
|
import random
|
||||||
import numpy as np
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import matplotlib.ticker as ticker
|
|
||||||
from torchvision import transforms
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from keypoint_extractor import KeypointExtractor
|
import matplotlib.pyplot as plt
|
||||||
from identifiers import LANDMARKS
|
import matplotlib.ticker as ticker
|
||||||
from model import SPOTER
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
from dataset import WLASLDataset
|
from dataset import WLASLDataset
|
||||||
|
from identifiers import LANDMARKS
|
||||||
|
from keypoint_extractor import KeypointExtractor
|
||||||
|
from model import SPOTER
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
random.seed(379)
|
random.seed(379)
|
||||||
@@ -31,7 +32,7 @@ def train():
|
|||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
spoter_model = SPOTER(num_classes=100, hidden_dim=2*75)
|
spoter_model = SPOTER(num_classes=100, hidden_dim=len(LANDMARKS) *2)
|
||||||
spoter_model.train(True)
|
spoter_model.train(True)
|
||||||
spoter_model.to(device)
|
spoter_model.to(device)
|
||||||
|
|
||||||
@@ -56,53 +57,54 @@ def train():
|
|||||||
|
|
||||||
|
|
||||||
train_acc, val_acc = 0, 0
|
train_acc, val_acc = 0, 0
|
||||||
losses, train_accs, val_accs = [], [], []
|
|
||||||
lr_progress = []
|
lr_progress = []
|
||||||
top_train_acc, top_val_acc = 0, 0
|
top_train_acc, top_val_acc = 0, 0
|
||||||
checkpoint_index = 0
|
checkpoint_index = 0
|
||||||
|
|
||||||
for epoch in range(100):
|
for epoch in range(100):
|
||||||
|
|
||||||
|
running_loss = 0.0
|
||||||
|
pred_correct, pred_all = 0, 0
|
||||||
|
|
||||||
# train
|
# train
|
||||||
for i, (inputs, labels) in enumerate(train_loader):
|
for i, (inputs, labels) in enumerate(train_loader):
|
||||||
inputs = inputs.squeeze(0).to(device)
|
inputs = inputs.squeeze(0).to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device, dtype=torch.long)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
outputs = spoter_model(inputs).expand(1, -1, -1)
|
outputs = spoter_model(inputs).expand(1, -1, -1)
|
||||||
loss = criterion(outputs[0], labels)
|
loss = criterion(outputs[0], labels)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
running_loss += loss
|
||||||
|
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
|
||||||
train_acc = (predicted == labels).sum().item() / labels.size(0)
|
pred_correct += 1
|
||||||
|
pred_all += 1
|
||||||
losses.append(loss.item())
|
|
||||||
train_accs.append(train_acc)
|
|
||||||
|
|
||||||
if i % 100 == 0:
|
if i % 100 == 0:
|
||||||
print(f"Epoch: {epoch} | Batch: {i} | Loss: {loss.item()} | Train Acc: {train_acc}")
|
print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
|
||||||
|
|
||||||
|
if scheduler:
|
||||||
|
scheduler.step(running_loss.item() / len(train_loader))
|
||||||
|
|
||||||
# validate
|
# validate
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i, (inputs, labels) in enumerate(val_loader):
|
for i, (inputs, labels) in enumerate(val_loader):
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.squeeze(0).to(device)
|
||||||
labels = labels.to(device)
|
labels = labels.to(device)
|
||||||
|
|
||||||
outputs = spoter_model(inputs)
|
outputs = spoter_model(inputs)
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
val_acc = (predicted == labels).sum().item() / labels.size(0)
|
val_acc = (predicted == labels).sum().item() / labels.size(0)
|
||||||
|
|
||||||
val_accs.append(val_acc)
|
|
||||||
|
|
||||||
scheduler.step(loss)
|
|
||||||
|
|
||||||
# save checkpoint
|
# save checkpoint
|
||||||
if val_acc > top_val_acc:
|
# if val_acc > top_val_acc:
|
||||||
top_val_acc = val_acc
|
# top_val_acc = val_acc
|
||||||
top_train_acc = train_acc
|
# top_train_acc = train_acc
|
||||||
checkpoint_index = epoch
|
# checkpoint_index = epoch
|
||||||
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
# torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
||||||
|
|
||||||
print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
|
print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
|
||||||
lr_progress.append(optimizer.param_groups[0]['lr'])
|
lr_progress.append(optimizer.param_groups[0]['lr'])
|
||||||
|
|||||||
Reference in New Issue
Block a user