Some fixes
This commit is contained in:
25
train.py
25
train.py
@@ -75,12 +75,25 @@ def train(args, tracker: Tracker):
|
||||
|
||||
# Construct the model
|
||||
if not args.classification_model:
|
||||
slrt_model = SPOTER_EMBEDDINGS(
|
||||
features=args.vector_length,
|
||||
hidden_dim=args.hidden_dim,
|
||||
norm_emb=args.normalize_embeddings,
|
||||
dropout=args.dropout
|
||||
)
|
||||
# if finetune, load the weights from the classification model
|
||||
if args.finetune:
|
||||
checkpoint = torch.load(args.checkpoint, map_location=device)
|
||||
|
||||
slrt_model = SPOTER_EMBEDDINGS(
|
||||
features=checkpoint["config_args"].vector_length,
|
||||
hidden_dim=checkpoint["config_args"].hidden_dim,
|
||||
norm_emb=checkpoint["config_args"].normalize_embeddings,
|
||||
dropout=checkpoint["config_args"].dropout,
|
||||
)
|
||||
else:
|
||||
slrt_model = SPOTER_EMBEDDINGS(
|
||||
features=args.vector_length,
|
||||
hidden_dim=args.hidden_dim,
|
||||
norm_emb=args.normalize_embeddings,
|
||||
dropout=args.dropout,
|
||||
)
|
||||
|
||||
|
||||
model_type = 'embed'
|
||||
if args.hard_triplet_mining == "None":
|
||||
cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)
|
||||
|
||||
Reference in New Issue
Block a user