sign-predictor/src/train.py
import argparse
import logging
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.wlasl_dataset import WLASLDataset
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
from model import SPOTER
def train():
    # Seed every RNG source so runs are repeatable
    random.seed(379)
    np.random.seed(379)
    os.environ['PYTHONHASHSEED'] = str(379)
    torch.manual_seed(379)
    torch.cuda.manual_seed(379)
    torch.cuda.manual_seed_all(379)
    torch.backends.cudnn.deterministic = True
    g = torch.Generator()
    g.manual_seed(379)
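    # Note: g is handed to every DataLoader below so the shuffle order is
    # reproducible across runs; cudnn.deterministic trades some speed for
    # repeatable convolution results on GPU.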
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spoter_model = SPOTER(num_classes=100, hidden_dim=len(LANDMARKS) * 2)
    spoter_model.train(True)
    spoter_model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(spoter_model.parameters(), lr=0.001, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
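    # ReduceLROnPlateau multiplies the learning rate by `factor` once the
    # monitored quantity (here the mean training loss passed to
    # scheduler.step() at the end of each epoch) has not improved for
    # `patience` epochs.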
    # TODO: create paths for checkpoints
    # TODO: transformations + augmentations
    k = KeypointExtractor("data/videos/")
    train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
    train_loader = DataLoader(train_set, shuffle=True, generator=g)
    val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
    val_loader = DataLoader(val_set, shuffle=True, generator=g)
    test_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="test")
    test_loader = DataLoader(test_set, shuffle=True, generator=g)
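    # The loaders fall back to DataLoader's default batch_size=1, so each
    # "batch" is a single clip; squeeze(0) in the loops below drops that
    # singleton batch dimension before the sequence enters the model.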
    train_acc, val_acc = 0, 0
    lr_progress = []
    top_train_acc, top_val_acc = 0, 0
    checkpoint_index = 0
    for epoch in range(100):
        running_loss = 0.0
        pred_correct, pred_all = 0, 0

        # train
        spoter_model.train(True)  # re-enable train mode after the previous epoch's eval pass
        for i, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device, dtype=torch.long)
            optimizer.zero_grad()
            # expand(1, -1, -1) gives the logits an explicit leading batch
            # dimension, so outputs[0] matches the (N, C) shape that
            # CrossEntropyLoss expects
            outputs = spoter_model(inputs).expand(1, -1, -1)
            loss = criterion(outputs[0], labels)
            loss.backward()
            optimizer.step()
            # accumulate a plain float; adding the loss tensor itself would
            # keep every batch's autograd graph alive for the whole epoch
            running_loss += loss.item()
            if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
                pred_correct += 1
            pred_all += 1
            if i % 100 == 0:
                print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss} | Train Acc: {pred_correct / pred_all}")
        train_acc = pred_correct / pred_all

        if scheduler:
            scheduler.step(running_loss / len(train_loader))
        # validate
        spoter_model.eval()  # disable dropout etc. while evaluating
        val_correct, val_all = 0, 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.squeeze(0).to(device)
                labels = labels.to(device)
                outputs = spoter_model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                val_correct += (predicted == labels).sum().item()
                val_all += labels.size(0)
        # accumulate over the whole split instead of overwriting val_acc
        # with the last sample's 0-or-1 accuracy
        val_acc = val_correct / val_all

        # save checkpoint
        # if val_acc > top_val_acc:
        #     top_val_acc = val_acc
        #     top_train_acc = train_acc
        #     checkpoint_index = epoch
        #     torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")

        print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
        lr_progress.append(optimizer.param_groups[0]['lr'])
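
    # Hedged sketch: the matplotlib.pyplot / matplotlib.ticker imports above
    # are otherwise unused, so presumably the collected lr_progress is meant
    # to be charted once training finishes. The output path "lr_progress.png"
    # is an assumption, not part of the original script.
    fig, ax = plt.subplots()
    ax.plot(lr_progress)
    ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%.0e"))
    ax.set_xlabel("epoch")
    ax.set_ylabel("learning rate")
    fig.savefig("lr_progress.png")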

if __name__ == "__main__":
    train()