Some fixes

This commit is contained in:
2023-04-17 15:52:19 +00:00
parent 2e66cccf50
commit 2f7063b70d
6 changed files with 158 additions and 61 deletions

View File

@@ -75,12 +75,25 @@ def train(args, tracker: Tracker):
# Construct the model
if not args.classification_model:
slrt_model = SPOTER_EMBEDDINGS(
features=args.vector_length,
hidden_dim=args.hidden_dim,
norm_emb=args.normalize_embeddings,
dropout=args.dropout
)
# if finetune, load the weights from the classification model
if args.finetune:
checkpoint = torch.load(args.checkpoint, map_location=device)
slrt_model = SPOTER_EMBEDDINGS(
features=checkpoint["config_args"].vector_length,
hidden_dim=checkpoint["config_args"].hidden_dim,
norm_emb=checkpoint["config_args"].normalize_embeddings,
dropout=checkpoint["config_args"].dropout,
)
else:
slrt_model = SPOTER_EMBEDDINGS(
features=args.vector_length,
hidden_dim=args.hidden_dim,
norm_emb=args.normalize_embeddings,
dropout=args.dropout,
)
model_type = 'embed'
if args.hard_triplet_mining == "None":
cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)