import argparse
import logging
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.wlasl_dataset import WLASLDataset
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
from model import SPOTER


def _seed_everything(seed):
    """Seed every RNG this script touches and return a seeded ``torch.Generator``.

    The returned generator is handed to the ``DataLoader`` instances so that
    their shuffling order is driven by the same seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator


def _is_correct(outputs, labels):
    """Return True when the top-scoring entry of *outputs* equals ``labels[0]``.

    The argmax is taken over the flattened tensor, so this is only meaningful
    when *outputs* holds the scores of a single sample (the loaders in
    ``train`` are built without ``batch_size`` and therefore use the
    ``DataLoader`` default of one sample per batch). Softmax is monotonic, so
    it is not needed to find the winning class.
    """
    return int(torch.argmax(outputs)) == int(labels[0])


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over *loader*.

    Returns:
        A ``(mean_loss, accuracy)`` tuple for the epoch; both are ``0.0`` when
        the loader yields no batches.
    """
    model.train()

    running_loss = 0.0
    pred_correct, pred_all = 0, 0

    for i, (inputs, labels) in enumerate(loader):
        inputs = inputs.squeeze(0).to(device)
        labels = labels.to(device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = model(inputs).expand(1, -1, -1)
        loss = criterion(outputs[0], labels)
        loss.backward()
        optimizer.step()

        # Accumulate a plain float: adding the tensor itself would keep every
        # iteration's autograd graph alive for the whole epoch.
        running_loss += loss.item()

        if _is_correct(outputs, labels):
            pred_correct += 1
        pred_all += 1

        if i % 100 == 0:
            print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss} | Train Acc: {(pred_correct / pred_all)}")

    if pred_all == 0:
        return 0.0, 0.0
    return running_loss / pred_all, pred_correct / pred_all


def _evaluate(model, loader, device):
    """Return the fraction of samples in *loader* that *model* classifies correctly.

    The model is put in eval mode and gradients are disabled. Predictions are
    read the same way as in ``_train_one_epoch`` so both accuracies are
    comparable. Returns ``0.0`` when the loader yields no batches.
    """
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(inputs).expand(1, -1, -1)
            if _is_correct(outputs, labels):
                correct += 1
            total += 1

    return correct / total if total else 0.0


def train():
    """Train SPOTER on the WLASL-100 split and report accuracy per epoch.

    Seeds all RNGs, builds the model, optimiser, LR scheduler and the
    train/val/test loaders from hard-coded paths under ``data/``, then runs
    100 epochs. Each epoch trains on the training loader, steps the
    ``ReduceLROnPlateau`` scheduler on the mean training loss, evaluates on
    the validation loader and prints both accuracies.
    """
    g = _seed_everything(379)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spoter_model = SPOTER(num_classes=100, hidden_dim=len(LANDMARKS) * 2)
    spoter_model.train(True)
    spoter_model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(spoter_model.parameters(), lr=0.001, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

    # TODO: create paths for checkpoints

    # TODO: transformations + augmentations
    k = KeypointExtractor("data/videos/")

    train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
    train_loader = DataLoader(train_set, shuffle=True, generator=g)

    val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
    val_loader = DataLoader(val_set, shuffle=True, generator=g)

    # The test split is loaded but not evaluated anywhere in this function yet.
    test_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="test")
    test_loader = DataLoader(test_set, shuffle=True, generator=g)

    train_acc, val_acc = 0, 0
    lr_progress = []
    top_train_acc, top_val_acc = 0, 0
    checkpoint_index = 0

    for epoch in range(100):
        # train
        mean_loss, train_acc = _train_one_epoch(
            spoter_model, train_loader, criterion, optimizer, device, epoch
        )

        if scheduler:
            scheduler.step(mean_loss)

        # validate
        val_acc = _evaluate(spoter_model, val_loader, device)

        # TODO: save checkpoint once the checkpoint paths exist
        # if val_acc > top_val_acc:
        #     top_val_acc = val_acc
        #     top_train_acc = train_acc
        #     checkpoint_index = epoch
        #     torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")

        print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")

        lr_progress.append(optimizer.param_groups[0]['lr'])


# Training starts as soon as this file is executed (or imported).
train()