120 lines
3.9 KiB
Python
120 lines
3.9 KiB
Python
import argparse
|
|
import logging
|
|
import os
|
|
import random
|
|
from pathlib import Path
|
|
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.ticker as ticker
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import transforms
|
|
|
|
from augmentations import MirrorKeypoints
|
|
from datasets.finger_spelling_dataset import FingerSpellingDataset
|
|
from datasets.wlasl_dataset import WLASLDataset
|
|
from identifiers import LANDMARKS
|
|
from keypoint_extractor import KeypointExtractor
|
|
from model import SPOTER
|
|
|
|
|
|
def train():
|
|
random.seed(379)
|
|
np.random.seed(379)
|
|
os.environ['PYTHONHASHSEED'] = str(379)
|
|
torch.manual_seed(379)
|
|
torch.cuda.manual_seed(379)
|
|
torch.cuda.manual_seed_all(379)
|
|
torch.backends.cudnn.deterministic = True
|
|
g = torch.Generator()
|
|
g.manual_seed(379)
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2)
|
|
spoter_model.train(True)
|
|
spoter_model.to(device)
|
|
|
|
criterion = nn.CrossEntropyLoss()
|
|
optimizer = optim.SGD(spoter_model.parameters(), lr=0.0001, momentum=0.9)
|
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
|
|
|
|
# TODO: create paths for checkpoints
|
|
|
|
# TODO: transformations + augmentations
|
|
|
|
k = KeypointExtractor("data/fingerspelling/data/")
|
|
|
|
transform = transforms.Compose([MirrorKeypoints()])
|
|
|
|
train_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS, subset="train", transform=transform)
|
|
train_loader = DataLoader(train_set, shuffle=True, generator=g)
|
|
|
|
val_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS, subset="val")
|
|
val_loader = DataLoader(val_set, shuffle=True, generator=g)
|
|
|
|
train_acc, val_acc = 0, 0
|
|
lr_progress = []
|
|
top_train_acc, top_val_acc = 0, 0
|
|
checkpoint_index = 0
|
|
|
|
for epoch in range(100):
|
|
|
|
running_loss = 0.0
|
|
pred_correct, pred_all = 0, 0
|
|
|
|
# train
|
|
for i, (inputs, labels) in enumerate(train_loader):
|
|
inputs = inputs.squeeze(0).to(device)
|
|
labels = labels.to(device, dtype=torch.long)
|
|
|
|
optimizer.zero_grad()
|
|
outputs = spoter_model(inputs).expand(1, -1, -1)
|
|
loss = criterion(outputs[0], labels)
|
|
loss.backward()
|
|
optimizer.step()
|
|
running_loss += loss
|
|
|
|
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
|
|
pred_correct += 1
|
|
pred_all += 1
|
|
|
|
# if i % 100 == 0:
|
|
# print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
|
|
|
|
if scheduler:
|
|
scheduler.step(running_loss.item() / len(train_loader))
|
|
|
|
# validate and print val acc
|
|
val_pred_correct, val_pred_all = 0, 0
|
|
with torch.no_grad():
|
|
for i, (inputs, labels) in enumerate(val_loader):
|
|
inputs = inputs.squeeze(0).to(device)
|
|
labels = labels.to(device, dtype=torch.long)
|
|
|
|
outputs = spoter_model(inputs).expand(1, -1, -1)
|
|
|
|
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
|
|
val_pred_correct += 1
|
|
val_pred_all += 1
|
|
|
|
val_acc = (val_pred_correct / val_pred_all)
|
|
|
|
print(f"Epoch: {epoch} | Train Acc: {(pred_correct / pred_all)} | Val Acc: {val_acc}")
|
|
|
|
|
|
# save checkpoint
|
|
if val_acc > top_val_acc:
|
|
top_val_acc = val_acc
|
|
top_train_acc = train_acc
|
|
checkpoint_index = epoch
|
|
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
|
|
|
lr_progress.append(optimizer.param_groups[0]['lr'])
|
|
|
|
print(f"Best val acc: {top_val_acc} | Best train acc: {top_train_acc} | Epoch: {checkpoint_index}")
|
|
|
|
train() |