Initial codebase (#1)
* Add project code * Logger improvements * Improvements to web demo code * added create_wlasl_landmarks_dataset.py and extract_mediapipe_landmarks.py * Fix rotation augmentation * fixed error in docstring, and removed unnecessary replace -1 -> 0 * Readme updates * Share base notebooks * Add notebooks and unify for different datasets * requirements update * fixes * Make evaluate more deterministic * Allow training with clearml * refactor preprocessing and apply linter * Minor fixes * Minor notebook tweaks * Readme updates * Fix PR comments * Remove unneeded code * Add banner to Readme --------- Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
This commit is contained in:
105
training/online_batch_mining.py
Normal file
105
training/online_batch_mining.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Small constant used below both as a positivity threshold (`triplet_loss > eps`)
# and as a denominator guard (`... + eps`) so averaging never divides by zero.
eps = 1e-8
|
||||
|
||||
# Adapted from https://qdrant.tech/articles/triplet-loss/
|
||||
|
||||
|
||||
class BatchAllTripletLoss(nn.Module):
    """Batch-all triplet loss: uses every valid (anchor, positive, negative)
    triplet in the batch to compute the Triplet Loss.

    Adapted from https://qdrant.tech/articles/triplet-loss/

    Args:
        device: Device the validity mask is moved to before masking
            (e.g. ``"cpu"`` / ``"cuda"``).
        margin: Margin value in the Triplet Loss equation.
        filter_easy_triplets: If True, average the loss only over triplets
            that produced a positive (hard / semi-hard) loss; if False,
            average over all valid triplets (easy ones contribute 0).
    """

    def __init__(self, device, margin=1., filter_easy_triplets=True):
        super().__init__()
        self.margin = margin
        self.device = device
        self.filter_easy_triplets = filter_easy_triplets

    def get_triplet_mask(self, labels):
        """Compute a boolean mask of valid triplets.

        A triplet (i, j, k) is valid if:
        `labels[i] == labels[j] and labels[i] != labels[k]`
        and `i`, `j`, `k` are all distinct.

        Args:
            labels: Batch of integer labels. Shape: (batch_size,).

        Returns:
            Bool tensor indicating which triplets are valid.
            Shape: (batch_size, batch_size, batch_size).
        """
        # step 1 - mask of triplets whose three indices are pairwise distinct
        # shape: (batch_size, batch_size)
        indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
        indices_not_equal = torch.logical_not(indices_equal)
        # broadcastable views over the three triplet axes
        i_not_equal_j = indices_not_equal.unsqueeze(2)  # (b, b, 1)
        i_not_equal_k = indices_not_equal.unsqueeze(1)  # (b, 1, b)
        j_not_equal_k = indices_not_equal.unsqueeze(0)  # (1, b, b)
        # shape: (batch_size, batch_size, batch_size)
        distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

        # step 2 - mask of label-valid anchor-positive-negative combinations
        # shape: (batch_size, batch_size)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        i_equal_j = labels_equal.unsqueeze(2)  # (b, b, 1)
        i_equal_k = labels_equal.unsqueeze(1)  # (b, 1, b)
        # anchor and positive share a label; anchor and negative do not
        valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

        # step 3 - combine the two masks
        return torch.logical_and(distinct_indices, valid_indices)

    def forward(self, embeddings, labels, filter_easy_triplets=None):
        """Compute the loss value for the batch.

        Args:
            embeddings: Batch of embeddings, e.g., output of the encoder.
                Shape: (batch_size, embedding_dim).
            labels: Batch of integer labels associated with embeddings.
                Shape: (batch_size,).
            filter_easy_triplets: Optional per-call override of the instance
                setting. When None (default), ``self.filter_easy_triplets``
                is used. NOTE: previously this argument was accepted but
                silently ignored; it is now honored.

        Returns:
            Tuple of (scalar loss tensor, number of valid triplets,
            number of triplets that actually contributed to the loss).
        """
        # Honor the per-call override; fall back to the constructor setting.
        if filter_easy_triplets is None:
            filter_easy_triplets = self.filter_easy_triplets

        # step 1 - pairwise Euclidean distance matrix
        # shape: (batch_size, batch_size)
        distance_matrix = torch.cdist(embeddings, embeddings, p=2)

        # step 2 - loss values for all possible n^3 triplets, by broadcasting
        # the distance matrix across the triplet axes
        anchor_positive_dists = distance_matrix.unsqueeze(2)  # (b, b, 1)
        anchor_negative_dists = distance_matrix.unsqueeze(1)  # (b, 1, b)
        # shape: (batch_size, batch_size, batch_size)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

        # step 3 - zero out invalid triplets, then easy ones
        # shape: (batch_size, batch_size, batch_size)
        mask = self.get_triplet_mask(labels)
        valid_triplets = mask.sum()
        triplet_loss *= mask.to(self.device)
        # easy triplets have negative loss values, so ReLU drops them to 0
        triplet_loss = F.relu(triplet_loss)

        if filter_easy_triplets:
            # step 4 - scalar loss: average only the positive (hard) losses
            num_positive_losses = (triplet_loss > eps).float().sum()
            # Factor in how many triplets were used relative to batch_size
            # (used_triplets * 3 / batch_size). The effect is similar to LR
            # decay, but it penalizes batches with fewer hard triplets.
            percent_used_factor = min(1.0, num_positive_losses * 3 / labels.size(0))

            triplet_loss = triplet_loss.sum() / (num_positive_losses + eps) * percent_used_factor
            return triplet_loss, valid_triplets, int(num_positive_losses)

        # Otherwise average over every valid triplet (easy ones count as 0).
        triplet_loss = triplet_loss.sum() / (valid_triplets + eps)
        return triplet_loss, valid_triplets, int(valid_triplets)
|
||||
Reference in New Issue
Block a user