Fingerspelling embedding + ClearML
This commit is contained in:
32
train.py
32
train.py
@@ -15,7 +15,7 @@ from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from pathlib import Path
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd
|
||||
from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \
|
||||
train_epoch_embedding_online, evaluate_embedding
|
||||
@@ -32,7 +32,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
PROJECT_NAME = "spoter"
|
||||
PROJECT_NAME = "SpoterEmbedding"
|
||||
CLEARML = "clearml"
|
||||
|
||||
|
||||
@@ -75,12 +75,25 @@ def train(args, tracker: Tracker):
|
||||
|
||||
# Construct the model
|
||||
if not args.classification_model:
|
||||
slrt_model = SPOTER_EMBEDDINGS(
|
||||
features=args.vector_length,
|
||||
hidden_dim=args.hidden_dim,
|
||||
norm_emb=args.normalize_embeddings,
|
||||
dropout=args.dropout
|
||||
)
|
||||
# if finetune, load the weights from the classification model
|
||||
if args.finetune:
|
||||
checkpoint = torch.load(args.checkpoint_path, map_location=device)
|
||||
|
||||
slrt_model = SPOTER_EMBEDDINGS(
|
||||
features=checkpoint["config_args"].vector_length,
|
||||
hidden_dim=checkpoint["config_args"].hidden_dim,
|
||||
norm_emb=checkpoint["config_args"].normalize_embeddings,
|
||||
dropout=checkpoint["config_args"].dropout,
|
||||
)
|
||||
else:
|
||||
slrt_model = SPOTER_EMBEDDINGS(
|
||||
features=args.vector_length,
|
||||
hidden_dim=args.hidden_dim,
|
||||
norm_emb=args.normalize_embeddings,
|
||||
dropout=args.dropout,
|
||||
)
|
||||
|
||||
|
||||
model_type = 'embed'
|
||||
if args.hard_triplet_mining == "None":
|
||||
cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)
|
||||
@@ -233,6 +246,9 @@ def train(args, tracker: Tracker):
|
||||
val_accs.append(val_acc)
|
||||
tracker.log_scalar_metric("acc", "val", epoch, val_acc)
|
||||
|
||||
create_embedding_scatter_plots(tracker, slrt_model, train_loader, val_loader, device, id_to_label, epoch,
|
||||
top_model_name)
|
||||
|
||||
logger.info(f"Epoch time: {datetime.now() - start_time}")
|
||||
logger.info("[" + str(epoch) + "] TRAIN loss: " + str(train_loss) + " acc: " + str(train_accs[-1]))
|
||||
logger.info("[" + str(epoch) + "] VALIDATION acc: " + str(val_accs[-1]))
|
||||
|
||||
Reference in New Issue
Block a user