Fingerspelling embedding + ClearML

This commit is contained in:
Victor Mylle
2023-05-21 20:30:12 +00:00
parent 2cbf11eb00
commit bd2b848eac
26 changed files with 2465 additions and 176 deletions

View File

@@ -15,7 +15,7 @@ from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
import copy
import numpy as np
from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd
from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \
train_epoch_embedding_online, evaluate_embedding
@@ -32,7 +32,7 @@ except ImportError:
pass
PROJECT_NAME = "spoter"
PROJECT_NAME = "SpoterEmbedding"
CLEARML = "clearml"
@@ -75,12 +75,25 @@ def train(args, tracker: Tracker):
# Construct the model
if not args.classification_model:
slrt_model = SPOTER_EMBEDDINGS(
features=args.vector_length,
hidden_dim=args.hidden_dim,
norm_emb=args.normalize_embeddings,
dropout=args.dropout
)
# if finetune, load the weights from the classification model
if args.finetune:
checkpoint = torch.load(args.checkpoint_path, map_location=device)
slrt_model = SPOTER_EMBEDDINGS(
features=checkpoint["config_args"].vector_length,
hidden_dim=checkpoint["config_args"].hidden_dim,
norm_emb=checkpoint["config_args"].normalize_embeddings,
dropout=checkpoint["config_args"].dropout,
)
else:
slrt_model = SPOTER_EMBEDDINGS(
features=args.vector_length,
hidden_dim=args.hidden_dim,
norm_emb=args.normalize_embeddings,
dropout=args.dropout,
)
model_type = 'embed'
if args.hard_triplet_mining == "None":
cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)
@@ -233,6 +246,9 @@ def train(args, tracker: Tracker):
val_accs.append(val_acc)
tracker.log_scalar_metric("acc", "val", epoch, val_acc)
create_embedding_scatter_plots(tracker, slrt_model, train_loader, val_loader, device, id_to_label, epoch,
top_model_name)
logger.info(f"Epoch time: {datetime.now() - start_time}")
logger.info("[" + str(epoch) + "] TRAIN loss: " + str(train_loss) + " acc: " + str(train_accs[-1]))
logger.info("[" + str(epoch) + "] VALIDATION acc: " + str(val_accs[-1]))