Fixed some bugs in the training loop; the results are still not good.
This commit is contained in:
68
src/train.py
68
src/train.py
@@ -1,22 +1,23 @@
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
import logging
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
from identifiers import LANDMARKS
|
||||
from model import SPOTER
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as ticker
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
from dataset import WLASLDataset
|
||||
from identifiers import LANDMARKS
|
||||
from keypoint_extractor import KeypointExtractor
|
||||
from model import SPOTER
|
||||
|
||||
|
||||
def train():
|
||||
random.seed(379)
|
||||
@@ -31,7 +32,7 @@ def train():
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
spoter_model = SPOTER(num_classes=100, hidden_dim=2*75)
|
||||
spoter_model = SPOTER(num_classes=100, hidden_dim=len(LANDMARKS) *2)
|
||||
spoter_model.train(True)
|
||||
spoter_model.to(device)
|
||||
|
||||
@@ -47,7 +48,7 @@ def train():
|
||||
|
||||
train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
|
||||
train_loader = DataLoader(train_set, shuffle=True, generator=g)
|
||||
|
||||
|
||||
val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
|
||||
val_loader = DataLoader(val_set, shuffle=True, generator=g)
|
||||
|
||||
@@ -56,53 +57,54 @@ def train():
|
||||
|
||||
|
||||
train_acc, val_acc = 0, 0
|
||||
losses, train_accs, val_accs = [], [], []
|
||||
lr_progress = []
|
||||
top_train_acc, top_val_acc = 0, 0
|
||||
checkpoint_index = 0
|
||||
|
||||
for epoch in range(100):
|
||||
|
||||
running_loss = 0.0
|
||||
pred_correct, pred_all = 0, 0
|
||||
|
||||
# train
|
||||
for i, (inputs, labels) in enumerate(train_loader):
|
||||
inputs = inputs.squeeze(0).to(device)
|
||||
labels = labels.to(device)
|
||||
labels = labels.to(device, dtype=torch.long)
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = spoter_model(inputs).expand(1, -1, -1)
|
||||
loss = criterion(outputs[0], labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
running_loss += loss
|
||||
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
train_acc = (predicted == labels).sum().item() / labels.size(0)
|
||||
|
||||
losses.append(loss.item())
|
||||
train_accs.append(train_acc)
|
||||
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
|
||||
pred_correct += 1
|
||||
pred_all += 1
|
||||
|
||||
if i % 100 == 0:
|
||||
print(f"Epoch: {epoch} | Batch: {i} | Loss: {loss.item()} | Train Acc: {train_acc}")
|
||||
print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
|
||||
|
||||
if scheduler:
|
||||
scheduler.step(running_loss.item() / len(train_loader))
|
||||
|
||||
# validate
|
||||
with torch.no_grad():
|
||||
for i, (inputs, labels) in enumerate(val_loader):
|
||||
inputs = inputs.to(device)
|
||||
inputs = inputs.squeeze(0).to(device)
|
||||
labels = labels.to(device)
|
||||
|
||||
outputs = spoter_model(inputs)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
val_acc = (predicted == labels).sum().item() / labels.size(0)
|
||||
|
||||
val_accs.append(val_acc)
|
||||
|
||||
scheduler.step(loss)
|
||||
|
||||
# save checkpoint
|
||||
if val_acc > top_val_acc:
|
||||
top_val_acc = val_acc
|
||||
top_train_acc = train_acc
|
||||
checkpoint_index = epoch
|
||||
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
||||
# if val_acc > top_val_acc:
|
||||
# top_val_acc = val_acc
|
||||
# top_train_acc = train_acc
|
||||
# checkpoint_index = epoch
|
||||
# torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
|
||||
|
||||
print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
|
||||
lr_progress.append(optimizer.param_groups[0]['lr'])
|
||||
|
||||
Reference in New Issue
Block a user