First training

This commit is contained in:
Victor Mylle
2023-03-02 11:18:57 +00:00
parent baeafe8c49
commit 246595780c
8 changed files with 307 additions and 46 deletions

View File

@@ -13,6 +13,8 @@ import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from augmentations import MirrorKeypoints
from datasets.finger_spelling_dataset import FingerSpellingDataset
from datasets.wlasl_dataset import WLASLDataset
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
@@ -32,30 +34,28 @@ def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spoter_model = SPOTER(num_classes=100, hidden_dim=len(LANDMARKS) *2)
spoter_model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2)
spoter_model.train(True)
spoter_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(spoter_model.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.SGD(spoter_model.parameters(), lr=0.0001, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
# TODO: create paths for checkpoints
# TODO: transformations + augmentations
k = KeypointExtractor("data/videos/")
k = KeypointExtractor("data/fingerspelling/data/")
train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
transform = transforms.Compose([MirrorKeypoints()])
train_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS, subset="train", transform=transform)
train_loader = DataLoader(train_set, shuffle=True, generator=g)
val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
val_set = FingerSpellingDataset("data/fingerspelling/data/", k, keypoints_identifier=LANDMARKS, subset="val")
val_loader = DataLoader(val_set, shuffle=True, generator=g)
test_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="test")
test_loader = DataLoader(test_set, shuffle=True, generator=g)
train_acc, val_acc = 0, 0
lr_progress = []
top_train_acc, top_val_acc = 0, 0
@@ -82,31 +82,39 @@ def train():
pred_correct += 1
pred_all += 1
if i % 100 == 0:
print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
# if i % 100 == 0:
# print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
if scheduler:
scheduler.step(running_loss.item() / len(train_loader))
# validate
# validate and print val acc
val_pred_correct, val_pred_all = 0, 0
with torch.no_grad():
for i, (inputs, labels) in enumerate(val_loader):
inputs = inputs.squeeze(0).to(device)
labels = labels.to(device)
labels = labels.to(device, dtype=torch.long)
outputs = spoter_model(inputs)
_, predicted = torch.max(outputs.data, 1)
val_acc = (predicted == labels).sum().item() / labels.size(0)
outputs = spoter_model(inputs).expand(1, -1, -1)
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
val_pred_correct += 1
val_pred_all += 1
val_acc = (val_pred_correct / val_pred_all)
print(f"Epoch: {epoch} | Train Acc: {(pred_correct / pred_all)} | Val Acc: {val_acc}")
# save checkpoint
# if val_acc > top_val_acc:
# top_val_acc = val_acc
# top_train_acc = train_acc
# checkpoint_index = epoch
# torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
if val_acc > top_val_acc:
top_val_acc = val_acc
top_train_acc = train_acc
checkpoint_index = epoch
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
lr_progress.append(optimizer.param_groups[0]['lr'])
print(f"Best val acc: {top_val_acc} | Best train acc: {top_train_acc} | Epoch: {checkpoint_index}")
train()