import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from augmentations import MirrorKeypoints
from datasets.finger_spelling_dataset import FingerSpellingDataset
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
from model import SPOTER


def train():
    # Seed every RNG source so runs are reproducible.
    random.seed(379)
    np.random.seed(379)
    os.environ["PYTHONHASHSEED"] = str(379)
    torch.manual_seed(379)
    torch.cuda.manual_seed(379)
    torch.cuda.manual_seed_all(379)
    torch.backends.cudnn.deterministic = True
    g = torch.Generator()
    g.manual_seed(379)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Two coordinates (x, y) per landmark, hence hidden_dim = 2 * len(LANDMARKS).
    spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) * 2)
    spoter_model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(spoter_model.parameters(), lr=0.0001, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

    # Make sure the checkpoint directory exists before the first save.
    Path("checkpoints").mkdir(parents=True, exist_ok=True)

    # TODO: additional transformations + augmentations
    k = KeypointExtractor("data/fingerspelling/data/")
    transform = transforms.Compose([MirrorKeypoints()])

    train_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS,
                                      subset="train", transform=transform)
    train_loader = DataLoader(train_set, shuffle=True, generator=g)

    val_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS,
                                    subset="val")
    val_loader = DataLoader(val_set, shuffle=True, generator=g)

    lr_progress = []
    top_train_acc, top_val_acc = 0, 0
    checkpoint_index = 0

    for epoch in range(100):
        running_loss = 0.0
        pred_correct, pred_all = 0, 0

        # Train.
        spoter_model.train(True)
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)
            optimizer.zero_grad()

            outputs = spoter_model(inputs).expand(1, -1, -1)
            loss = criterion(outputs[0], labels)
            loss.backward()
            optimizer.step()
            # Accumulate a float, not a tensor, so the autograd graph is freed each step.
            running_loss += loss.item()

            if int(torch.argmax(F.softmax(outputs, dim=2))) == int(labels[0]):
                pred_correct += 1
            pred_all += 1

            # if i % 100 == 0:
            #     print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss} | Train Acc: {pred_correct / pred_all}")

        train_acc = pred_correct / pred_all

        if scheduler:
            scheduler.step(running_loss / len(train_loader))

        # Validate; eval() disables dropout, no_grad() skips gradient bookkeeping.
        spoter_model.eval()
        val_pred_correct, val_pred_all = 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.squeeze(0).to(device)
                labels = labels.to(device, dtype=torch.long)

                outputs = spoter_model(inputs).expand(1, -1, -1)
                if int(torch.argmax(F.softmax(outputs, dim=2))) == int(labels[0]):
                    val_pred_correct += 1
                val_pred_all += 1

        val_acc = val_pred_correct / val_pred_all
        print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")

        # Checkpoint whenever validation accuracy improves.
        if val_acc > top_val_acc:
            top_val_acc = val_acc
            top_train_acc = train_acc
            checkpoint_index = epoch
            torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")

        lr_progress.append(optimizer.param_groups[0]["lr"])

    print(f"Best val acc: {top_val_acc} | Best train acc: {top_train_acc} | Epoch: {checkpoint_index}")
    # Expose the collected learning-rate history for optional plotting (see below).
    return lr_progress


if __name__ == "__main__":
    train()
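
# --- Optional: visualize the learning-rate schedule ---
# A minimal sketch, not part of the original training flow: train() collects
# `lr_progress` (and matplotlib is imported above, presumably for this purpose)
# but never plots it. The function name and output path are hypothetical.
# Usage, e.g. from an interactive session (the module name "train" below is an
# assumption about this file's name):
#     from train import train, plot_lr_schedule
#     plot_lr_schedule(train())
def plot_lr_schedule(lrs, out_path="lr_schedule.png"):
    fig, ax = plt.subplots()
    ax.plot(range(1, len(lrs) + 1), lrs)
    # ReduceLROnPlateau steps show up as downward jumps in this curve.
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))  # whole-epoch ticks
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Learning rate")
    fig.savefig(out_path)
    plt.close(fig)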