Added training loop and model

2023-02-21 23:24:51 +00:00
parent 1e05c02a7e
commit 98f29f683e
6 changed files with 191 additions and 10 deletions

src/train.py Normal file

@@ -0,0 +1,110 @@
import argparse
import logging
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import WLASLDataset
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
from model import SPOTER


def train():
    # Seed every RNG source so training runs are reproducible.
    random.seed(379)
    np.random.seed(379)
    os.environ['PYTHONHASHSEED'] = str(379)
    torch.manual_seed(379)
    torch.cuda.manual_seed(379)
    torch.cuda.manual_seed_all(379)
    torch.backends.cudnn.deterministic = True

    # A dedicated generator keeps DataLoader shuffling deterministic as well.
    g = torch.Generator()
    g.manual_seed(379)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    spoter_model = SPOTER(num_classes=100, hidden_dim=2 * 75)
    spoter_model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(spoter_model.parameters(), lr=0.001, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)

    # Create the checkpoint directory up front so torch.save below cannot fail.
    Path("checkpoints").mkdir(parents=True, exist_ok=True)

    # TODO: transformations + augmentations
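    # A possible starting point for the TODO above, as a minimal sketch only.
    # It assumes the dataset yields float tensors of (x, y) keypoint
    # coordinates; the helper name and noise scale are hypothetical.
    #
    # def jitter_keypoints(keypoints: torch.Tensor, std: float = 0.01) -> torch.Tensor:
    #     """Add small Gaussian noise to every keypoint coordinate."""
    #     return keypoints + torch.randn_like(keypoints) * std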

    # The same extractor instance serves all three splits.
    k = KeypointExtractor("data/videos/")

    train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
    train_loader = DataLoader(train_set, shuffle=True, generator=g)

    val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
    val_loader = DataLoader(val_set, shuffle=True, generator=g)

    # The test split is loaded here but not consumed by this loop yet.
    test_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="test")
    test_loader = DataLoader(test_set, shuffle=True, generator=g)

    train_acc, val_acc = 0, 0
    losses, train_accs, val_accs = [], [], []
    lr_progress = []
    top_train_acc, top_val_acc = 0, 0
    checkpoint_index = 0
    for epoch in range(100):
        # train
        spoter_model.train()
        running_correct, running_total = 0, 0
        for i, (inputs, labels) in enumerate(train_loader):
            # DataLoader batch size is 1; drop the batch dim before the model.
            inputs = inputs.squeeze(0).to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = spoter_model(inputs).expand(1, -1, -1)
            loss = criterion(outputs[0], labels)
            loss.backward()
            optimizer.step()

            # outputs[0] has shape (1, num_classes), matching the labels.
            _, predicted = torch.max(outputs[0], 1)
            running_correct += (predicted == labels).sum().item()
            running_total += labels.size(0)

            losses.append(loss.item())
            if i % 100 == 0:
                print(f"Epoch: {epoch} | Batch: {i} | Loss: {loss.item()} | Train Acc: {running_correct / running_total}")

        train_acc = running_correct / running_total
        train_accs.append(train_acc)

        # validate
        spoter_model.eval()
        running_correct, running_total = 0, 0
        val_loss = 0.0
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(val_loader):
                # Same preprocessing as in training.
                inputs = inputs.squeeze(0).to(device)
                labels = labels.to(device)

                outputs = spoter_model(inputs).expand(1, -1, -1)
                val_loss += criterion(outputs[0], labels).item()

                _, predicted = torch.max(outputs[0], 1)
                running_correct += (predicted == labels).sum().item()
                running_total += labels.size(0)

        val_acc = running_correct / running_total
        val_accs.append(val_acc)
        val_loss /= len(val_loader)

        # Drive the LR schedule with the epoch-level validation loss rather
        # than the loss of whichever training batch happened to come last.
        scheduler.step(val_loss)

        # save checkpoint
        if val_acc > top_val_acc:
            top_val_acc = val_acc
            top_train_acc = train_acc
            checkpoint_index = epoch
            torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")

        print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
        lr_progress.append(optimizer.param_groups[0]['lr'])
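

# A minimal sketch of how the collected curves could be plotted with the
# matplotlib imports above. The helper name, figure layout, and output path
# are assumptions, not something this commit defines.
def plot_training_curves(losses, train_accs, val_accs):
    fig, (ax_loss, ax_acc) = plt.subplots(2, 1)
    ax_loss.plot(losses)
    ax_loss.set_ylabel("train loss (per batch)")
    ax_acc.plot(train_accs, label="train")
    ax_acc.plot(val_accs, label="val")
    ax_acc.set_ylabel("accuracy (per epoch)")
    ax_acc.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    ax_acc.legend()
    fig.savefig("training_curves.png")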


if __name__ == "__main__":
    train()