# Files
# sign-predictor/src/train.py
# Victor Mylle 246595780c First training
# 2023-03-02 11:18:57 +00:00
#
# 120 lines
# 3.9 KiB
# Python

import argparse
import logging
import os
import random
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from augmentations import MirrorKeypoints
from datasets.finger_spelling_dataset import FingerSpellingDataset
from datasets.wlasl_dataset import WLASLDataset
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
from model import SPOTER
def train():
    """Train the SPOTER sign-language classifier on the fingerspelling data.

    Runs a fixed 100-epoch SGD loop with ReduceLROnPlateau scheduling,
    evaluates on the validation split after every epoch, and saves a
    checkpoint to ``checkpoints/spoter_<epoch>.pth`` whenever validation
    accuracy improves. Prints per-epoch and final best accuracies.
    """
    # Seed every RNG source so runs are reproducible.
    random.seed(379)
    np.random.seed(379)
    os.environ['PYTHONHASHSEED'] = str(379)
    torch.manual_seed(379)
    torch.cuda.manual_seed(379)
    torch.cuda.manual_seed_all(379)
    torch.backends.cudnn.deterministic = True
    g = torch.Generator()
    g.manual_seed(379)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) * 2)
    spoter_model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(spoter_model.parameters(), lr=0.0001, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

    # Resolves the old "create paths for checkpoints" TODO: the first
    # torch.save below would fail if the directory does not exist.
    Path("checkpoints").mkdir(parents=True, exist_ok=True)

    # TODO: transformations + augmentations
    k = KeypointExtractor("data/fingerspelling/data/")
    transform = transforms.Compose([MirrorKeypoints()])
    train_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS, subset="train", transform=transform)
    train_loader = DataLoader(train_set, shuffle=True, generator=g)
    val_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS, subset="val")
    val_loader = DataLoader(val_set, shuffle=True, generator=g)

    train_acc, val_acc = 0.0, 0.0
    lr_progress = []
    top_train_acc, top_val_acc = 0.0, 0.0
    checkpoint_index = 0

    for epoch in range(100):
        running_loss = 0.0
        pred_correct, pred_all = 0, 0

        # --- training pass ---
        spoter_model.train(True)  # re-enable dropout after last epoch's eval()
        for inputs, labels in train_loader:
            # DataLoader adds a batch dim of 1; the model expects the raw
            # (frames, keypoints) sequence — TODO confirm against SPOTER.
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)
            optimizer.zero_grad()
            outputs = spoter_model(inputs).expand(1, -1, -1)
            loss = criterion(outputs[0], labels)
            loss.backward()
            optimizer.step()
            # Accumulate a detached Python float: `running_loss += loss`
            # kept the loss *tensor*, retaining every batch's autograd
            # graph in memory for the whole epoch.
            running_loss += loss.item()
            # Batch size is 1, so per-batch accuracy is a single argmax check.
            if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
                pred_correct += 1
            pred_all += 1
        # BUG FIX: train_acc was never assigned before, so the "best train
        # acc" reported at the end was always 0.
        train_acc = pred_correct / pred_all if pred_all else 0.0

        if scheduler:
            scheduler.step(running_loss / len(train_loader))

        # --- validation pass ---
        spoter_model.eval()  # disable dropout/batchnorm updates during eval
        val_pred_correct, val_pred_all = 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.squeeze(0).to(device)
                labels = labels.to(device, dtype=torch.long)
                outputs = spoter_model(inputs).expand(1, -1, -1)
                if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
                    val_pred_correct += 1
                val_pred_all += 1
        val_acc = val_pred_correct / val_pred_all if val_pred_all else 0.0

        print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")

        # Keep a checkpoint whenever validation accuracy improves.
        if val_acc > top_val_acc:
            top_val_acc = val_acc
            top_train_acc = train_acc
            checkpoint_index = epoch
            torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")

        lr_progress.append(optimizer.param_groups[0]['lr'])

    print(f"Best val acc: {top_val_acc} | Best train acc: {top_train_acc} | Epoch: {checkpoint_index}")
train()