Fixed some bugs in the training loop; still no good results

This commit is contained in:
2023-02-23 12:06:26 +00:00
parent 98f29f683e
commit 97ede38e3a
5 changed files with 68 additions and 63 deletions

View File

@@ -1,22 +1,23 @@
import os
import argparse
import random
import logging
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from torchvision import transforms
from torch.utils.data import DataLoader
import os
import random
from pathlib import Path
from keypoint_extractor import KeypointExtractor
from identifiers import LANDMARKS
from model import SPOTER
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import WLASLDataset
from identifiers import LANDMARKS
from keypoint_extractor import KeypointExtractor
from model import SPOTER
def train():
random.seed(379)
@@ -31,7 +32,7 @@ def train():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spoter_model = SPOTER(num_classes=100, hidden_dim=2*75)
spoter_model = SPOTER(num_classes=100, hidden_dim=len(LANDMARKS) *2)
spoter_model.train(True)
spoter_model.to(device)
@@ -47,7 +48,7 @@ def train():
train_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="train")
train_loader = DataLoader(train_set, shuffle=True, generator=g)
val_set = WLASLDataset("data/nslt_100.json", "data/missing.txt", k, keypoints_identifier=LANDMARKS, subset="val")
val_loader = DataLoader(val_set, shuffle=True, generator=g)
@@ -56,53 +57,54 @@ def train():
train_acc, val_acc = 0, 0
losses, train_accs, val_accs = [], [], []
lr_progress = []
top_train_acc, top_val_acc = 0, 0
checkpoint_index = 0
for epoch in range(100):
running_loss = 0.0
pred_correct, pred_all = 0, 0
# train
for i, (inputs, labels) in enumerate(train_loader):
inputs = inputs.squeeze(0).to(device)
labels = labels.to(device)
labels = labels.to(device, dtype=torch.long)
optimizer.zero_grad()
outputs = spoter_model(inputs).expand(1, -1, -1)
loss = criterion(outputs[0], labels)
loss.backward()
optimizer.step()
running_loss += loss
_, predicted = torch.max(outputs.data, 1)
train_acc = (predicted == labels).sum().item() / labels.size(0)
losses.append(loss.item())
train_accs.append(train_acc)
if int(torch.argmax(torch.nn.functional.softmax(outputs, dim=2))) == int(labels[0]):
pred_correct += 1
pred_all += 1
if i % 100 == 0:
print(f"Epoch: {epoch} | Batch: {i} | Loss: {loss.item()} | Train Acc: {train_acc}")
print(f"Epoch: {epoch} | Batch: {i} | Loss: {running_loss.item()} | Train Acc: {(pred_correct / pred_all)}")
if scheduler:
scheduler.step(running_loss.item() / len(train_loader))
# validate
with torch.no_grad():
for i, (inputs, labels) in enumerate(val_loader):
inputs = inputs.to(device)
inputs = inputs.squeeze(0).to(device)
labels = labels.to(device)
outputs = spoter_model(inputs)
_, predicted = torch.max(outputs.data, 1)
val_acc = (predicted == labels).sum().item() / labels.size(0)
val_accs.append(val_acc)
scheduler.step(loss)
# save checkpoint
if val_acc > top_val_acc:
top_val_acc = val_acc
top_train_acc = train_acc
checkpoint_index = epoch
torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
# if val_acc > top_val_acc:
# top_val_acc = val_acc
# top_train_acc = train_acc
# checkpoint_index = epoch
# torch.save(spoter_model.state_dict(), f"checkpoints/spoter_{epoch}.pth")
print(f"Epoch: {epoch} | Train Acc: {train_acc} | Val Acc: {val_acc}")
lr_progress.append(optimizer.param_groups[0]['lr'])