Initial codebase (#1)

* Add project code

* Logger improvements

* Improvements to web demo code

* added create_wlasl_landmarks_dataset.py and extract_mediapipe_landmarks.py

* Fix rotation augmentation

* fixed error in docstring, and removed unnecessary replace -1 -> 0

* Readme updates

* Share base notebooks

* Add notebooks and unify for different datasets

* requirements update

* fixes

* Make evaluate more deterministic

* Allow training with clearml

* refactor preprocessing and apply linter

* Minor fixes

* Minor notebook tweaks

* Readme updates

* Fix PR comments

* Remove unneeded code

* Add banner to Readme

---------

Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
This commit is contained in:
Mathias Claassen
2023-03-03 10:07:54 -03:00
committed by GitHub
parent 661e4bbc03
commit 81bbf66aab
49 changed files with 4205 additions and 0 deletions

0
training/__init__.py Normal file
View File

215
training/batch_sorter.py Normal file
View File

@@ -0,0 +1,215 @@
import logging
from datetime import datetime
import numpy as np
from typing import Optional
from .batching_scheduler import BatchingScheduler
import torch
logger = logging.getLogger("BatchGrouper")


class BatchGrouper:
    """
    Will cluster all `total_items` into `max_groups` clusters based on distances in
    `sorted_dists`. Each group has `mini_batch_size` elements and these elements are just integers in
    range 0...total_items.

    `sorted_dists` is expected to be a (3, num_pairs) array whose columns are
    (item_a, item_b, distance), sorted by distance descending.
    Distances between these items are expected to be scaled to 0...1 in a way that distances for two items in the
    same class are higher if closer to 1, while distances between elements of different classes are higher if closer
    to 0.

    The logic is picking the highest value distance and assigning both items to the same cluster/group if possible.
    This might include merging 2 clusters.
    There are a few thresholds to limit the computational cost. If the scaled distance between a pair is below
    `dist_threshold`, or more than `assign_threshold` percent of items have been assigned to the groups, we stop and
    assign the remaining items to the groups that have space left.
    """

    # Class-level defaults; every one of these is shadowed per instance in __init__.
    # Counters
    next_group = 0
    items_assigned = 0
    # Thresholds
    dist_threshold = 0.5
    assign_threshold = 0.80

    def __init__(self, sorted_dists, total_items, mini_batch_size=32, dist_threshold=0.5,
                 assign_threshold=0.8) -> None:
        self.sorted_dists = sorted_dists
        self.total_items = total_items
        self.mini_batch_size = mini_batch_size
        # NOTE(review): if total_items is not a multiple of mini_batch_size, total capacity
        # (max_groups * mini_batch_size) is smaller than total_items and
        # assign_remaining_items cannot place everything — callers should drop the remainder.
        self.max_groups = int(total_items / mini_batch_size)
        self.groups = {}          # group id -> list of assigned item indices
        self.item_to_group = {}   # item index -> group id
        self.items_assigned = 0
        self.next_group = 0
        self.dist_threshold = dist_threshold
        self.assign_threshold = assign_threshold

    def cluster_items(self):
        """Main function of this class. Does the clustering explained in the class docstring.

        :return list: A permutation of 0...total_items-1 where each consecutive run of
            `mini_batch_size` items forms one mined group.
        """
        for i in range(self.sorted_dists.shape[-1]):
            a, b, dist = self.sorted_dists[:, i]
            a, b = int(a), int(b)
            # Stop early once pairs get too easy or enough items are already placed.
            if dist < self.dist_threshold or self.items_assigned > self.total_items * self.assign_threshold:
                logger.info(f"Breaking with dist: {dist}, and {self.items_assigned} items assigned")
                break
            if a not in self.item_to_group and b not in self.item_to_group:
                g = self.create_or_get_group()
                self.assign_group(a, g)
                # Guard: the fallback in create_or_get_group may return a group with
                # space for only one item; never overfill a group.
                if not self.group_is_full(g):
                    self.assign_group(b, g)
            elif a not in self.item_to_group:
                if not self.group_is_full(self.item_to_group[b]):
                    self.assign_group(a, self.item_to_group[b])
            elif b not in self.item_to_group:
                if not self.group_is_full(self.item_to_group[a]):
                    self.assign_group(b, self.item_to_group[a])
            else:
                grp_a = self.item_to_group[a]
                grp_b = self.item_to_group[b]
                self.merge_groups(grp_a, grp_b)
        self.assign_remaining_items()
        return list(np.concatenate(list(self.groups.values())).flat)

    def assign_group(self, item, group):
        """Assigns `item` to group `group`."""
        self.item_to_group[item] = group
        self.groups[group].append(item)
        self.items_assigned += 1

    def create_or_get_group(self):
        """Creates a new group if current group count is less than max_groups.
        Otherwise returns the first group with space for a pair of items, falling back
        to the first group with any space left.

        :return int: The group id
        :raises RuntimeError: If every group is completely full (should not happen while
            the assign_threshold stopping rule holds).
        """
        if self.next_group < self.max_groups:
            group = self.next_group
            self.groups[group] = []
            self.next_group += 1
            return group
        # Prefer a group that can still take both items of a pair.
        for i in range(self.next_group):
            if len(self.groups[i]) <= self.mini_batch_size - 2:
                return i
        # Fallback (fixes UnboundLocalError in the original when no group had 2 free
        # slots): accept a group with a single free slot.
        for i in range(self.next_group):
            if not self.group_is_full(i):
                return i
        raise RuntimeError("No group with free space left")

    def group_is_full(self, group):
        """Returns True if `group` already holds `mini_batch_size` items."""
        return len(self.groups[group]) == self.mini_batch_size

    def can_merge_groups(self, grp_a, grp_b):
        """Two distinct groups can merge only if the union still fits in one mini-batch."""
        return grp_a != grp_b and (len(self.groups[grp_a]) + len(self.groups[grp_b]) < self.mini_batch_size)

    def merge_groups(self, grp_a, grp_b):
        """Will merge two groups together, if possible. Otherwise does nothing."""
        # Always merge into the lower-numbered group.
        if grp_a > grp_b:
            grp_a, grp_b = grp_b, grp_a
        if self.can_merge_groups(grp_a, grp_b):
            logger.debug(f"MERGE {grp_a} with {grp_b}: {len(self.groups[grp_a])} {len(self.groups[grp_b])}")
            for b in self.groups[grp_b]:
                self.item_to_group[b] = grp_a
            self.groups[grp_a].extend(self.groups[grp_b])
            self.groups[grp_b] = []
            self.replace_group(grp_b)

    def replace_group(self, group):
        """Replace a group with the last one in the list, keeping group ids contiguous.

        :param int group: Group to replace (assumed empty by the caller)
        """
        grp_to_change = self.next_group - 1
        if grp_to_change != group:
            # Move the last group's items into the freed slot.
            for item in self.groups[grp_to_change]:
                self.item_to_group[item] = group
            self.groups[group] = self.groups[grp_to_change]
        # Drop the (now duplicated or empty) last slot in either case.
        del self.groups[grp_to_change]
        self.next_group -= 1

    def assign_remaining_items(self):
        """Assign remaining items into groups, filling each group in id order."""
        grp_pointer = 0
        i = 0
        logger.info(f"Assigning rest of items: {self.items_assigned} of {self.total_items}")
        while i < self.total_items:
            if i not in self.item_to_group:
                if grp_pointer not in self.groups:
                    # This would happen if a group is still empty at this stage
                    assert grp_pointer < self.max_groups
                    new_group = self.create_or_get_group()
                    assert new_group == grp_pointer
                if len(self.groups[grp_pointer]) < self.mini_batch_size:
                    self.assign_group(i, grp_pointer)
                    i += 1
                else:
                    grp_pointer += 1
            else:
                i += 1
def get_dist_tuple_list(dist_matrix):
    """Flattens the strictly-lower triangle of `dist_matrix` into a (3, num_pairs)
    tensor whose columns are (row_index, col_index, distance)."""
    n = dist_matrix.size()[0]
    pair_idx = torch.tril_indices(n, n, offset=-1)
    rows, cols = pair_idx[0], pair_idx[1]
    dists = dist_matrix[rows, cols].cpu().unsqueeze(0)
    return torch.cat([pair_idx, dists], dim=0)
def get_scaled_distances(embeddings, labels, device, same_label_factor=1):
    """Returns distance matrix between all embeddings scaled to the 0-1 range where 0 is good and 1 is bad.

    This means that small distances for embeddings of the same class will be close to 0 while small distances for
    embeddings of different classes will be close to 1.

    :param _type_ embeddings: Embeddings of batch items
    :param _type_ labels: Labels associated to the embeddings
    :param _type_ device: Device to run on (cuda or cpu)
    :param int same_label_factor: Multiplies the weight of same-class distances allowing to give more or less importance
        to these compared to distinct-class distances, defaults to 1 (which means equal weight)
    :return torch.Tensor: Scaled distance matrix
    """
    # Pairwise Euclidean distances, shape: (batch_size, batch_size)
    dist = torch.cdist(embeddings, embeddings, p=2)
    labels = labels.to(device)
    # Boolean masks for same-class / different-class pairs
    same = (labels.unsqueeze(0) == labels.unsqueeze(1)).squeeze()
    different = torch.logical_not(same)
    same_dist = dist * same
    diff_dist = dist * different
    # Normalize each half by its own max so both land in 0-1;
    # same-class: larger distance -> closer to 1 (harder)
    scaled_same = same_dist / same_dist.max() * same_label_factor
    # different-class: smaller distance -> closer to 1 (harder)
    scaled_diff = 1 * different - (diff_dist / diff_dist.max())
    return scaled_same + scaled_diff
def sort_batches(inputs, labels, masks, embeddings, device, mini_batch_size=32,
                 scheduler: Optional[BatchingScheduler] = None):
    """Reorders a pre-batch so consecutive `mini_batch_size` slices form hard mini-batches.

    Computes scaled pairwise embedding distances, sorts the pairs from hardest to easiest
    and lets BatchGrouper assign items to groups; inputs, labels and masks are then
    reindexed in that group order.

    :param torch.Tensor inputs: Batch inputs (first dimension indexes items)
    :param torch.Tensor labels: Labels associated with the inputs
    :param torch.Tensor masks: Masks associated with the inputs
    :param torch.Tensor embeddings: Embeddings of the inputs used to measure hardness
    :param device: Device to run on (cuda or cpu)
    :param int mini_batch_size: Size of each mined group, defaults to 32
    :param Optional[BatchingScheduler] scheduler: Optional scheduler providing the
        current scaling factor and distance threshold
    :return tuple: (inputs, labels, masks) reordered by the mined grouping
    """
    start = datetime.now()
    factor = 1 if scheduler is None else scheduler.get_scaling_same_label_factor()
    scaled_dist = get_scaled_distances(embeddings, labels, device, factor)
    # Flatten to (row, column, dist) columns, then sort descending by distance
    pair_list = get_dist_tuple_list(scaled_dist).cpu().detach().numpy()
    order = pair_list[-1, :].argsort()[::-1]
    descending_pairs = pair_list[:, order]
    # Walk the sorted pair list assigning both items of each pair to the same group
    threshold = 0.5 if scheduler is None else scheduler.get_dist_threshold()
    grouper = BatchGrouper(descending_pairs, total_items=labels.size()[0],
                           mini_batch_size=mini_batch_size, dist_threshold=threshold)
    indices = torch.tensor(grouper.cluster_items()).type(torch.IntTensor)
    sorted_inputs = torch.index_select(inputs, dim=0, index=indices)
    sorted_labels = torch.index_select(labels, dim=0, index=indices)
    sorted_masks = torch.index_select(masks, dim=0, index=indices)
    logger.info(f"Batch sorting took: {datetime.now() - start}")
    return sorted_inputs, sorted_labels, sorted_masks

View File

@@ -0,0 +1,62 @@
from collections import deque
import numpy as np
class BatchingScheduler():
    """This class acts as scheduler for the batching algorithm.

    It keeps a rolling window of how many triplets recent steps used; once that average
    drops below `triplets_threshold` (and the cooldown since the last update has passed),
    both the distance threshold and the same-label scaling factor are decayed.
    """

    def __init__(self, decay_factor=0.8, min_threshold=0.2, triplets_threshold=10, cooldown=10) -> None:
        # internal vars
        self._step_count = 0
        self._dist_threshold = 0.5
        self._last_used_triplets = deque([], 5)  # rolling window over the last 5 steps
        self._scaling_same_label_factor = 1
        self._last_update_step = -10  # far enough in the past that step 1 may update
        # Parameters
        self.decay_factor = decay_factor
        self.min_threshold = min_threshold
        self.triplets_threshold = triplets_threshold
        self.cooldown = cooldown

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`."""
        return dict(self.__dict__)

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self, used_triplets):
        """Records one training step's used-triplet count and decays thresholds if due."""
        self._step_count += 1
        self._last_used_triplets.append(used_triplets)
        few_triplets = np.mean(self._last_used_triplets) < self.triplets_threshold
        cooldown_elapsed = self._last_update_step + self.cooldown <= self._step_count
        if few_triplets and cooldown_elapsed:
            if self._dist_threshold > self.min_threshold:
                print(f"Updating dist_threshold at {self._step_count} ({np.mean(self._last_used_triplets)})")
                self.update_dist_threshold()
            if self._scaling_same_label_factor > 0.6:
                print(f"Updating scale factor at {self._step_count} ({np.mean(self._last_used_triplets)})")
                self.update_scale_factor()
            self._last_update_step = self._step_count

    def update_scale_factor(self):
        """Decays the same-label scaling factor by 10%, floored at 0.6."""
        self._scaling_same_label_factor = max(self._scaling_same_label_factor * 0.9, 0.6)
        print(f"Updating scaling factor to {self._scaling_same_label_factor}")

    def update_dist_threshold(self):
        """Decays the distance threshold by `decay_factor`, floored at `min_threshold`."""
        self._dist_threshold = max(self.min_threshold, self._dist_threshold * self.decay_factor)
        print(f"Updated dist_threshold to {self._dist_threshold}")

    def get_dist_threshold(self) -> float:
        return self._dist_threshold

    def get_scaling_same_label_factor(self) -> float:
        return self._scaling_same_label_factor

View File

@@ -0,0 +1,18 @@
import torch
class GaussianNoise(object):
    """Transform that adds Gaussian noise drawn from N(mean, std) to a tensor."""

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Sample standard normal noise matching the input shape, then scale and shift.
        noise = torch.randn(tensor.size())
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
if __name__ == "__main__":
    # Intentionally a no-op: this module only defines the GaussianNoise transform.
    pass

View File

@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
eps = 1e-8  # an arbitrary small value to be used for numerical stability tricks


# Adapted from https://qdrant.tech/articles/triplet-loss/
class BatchAllTripletLoss(nn.Module):
    """Uses all valid triplets to compute Triplet loss

    Args:
        device: Device the triplet mask is moved to (cuda or cpu)
        margin: Margin value in the Triplet Loss equation
        filter_easy_triplets: If True, average the loss only over triplets with a
            positive (hard/semi-hard) loss value; otherwise over all valid triplets
    """

    def __init__(self, device, margin=1., filter_easy_triplets=True):
        super().__init__()
        self.margin = margin
        self.device = device
        self.filter_easy_triplets = filter_easy_triplets

    def get_triplet_mask(self, labels):
        """compute a mask for valid triplets

        Args:
            labels: Batch of integer labels. shape: (batch_size,)

        Returns:
            Mask tensor to indicate which triplets are actually valid. Shape: (batch_size, batch_size, batch_size)
            A triplet is valid if:
            `labels[i] == labels[j] and labels[i] != labels[k]`
            and `i`, `j`, `k` are different.
        """
        # step 1 - get a mask for distinct indices
        # shape: (batch_size, batch_size)
        indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
        indices_not_equal = torch.logical_not(indices_equal)
        # shape: (batch_size, batch_size, 1)
        i_not_equal_j = indices_not_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_not_equal_k = indices_not_equal.unsqueeze(1)
        # shape: (1, batch_size, batch_size)
        j_not_equal_k = indices_not_equal.unsqueeze(0)
        # Shape: (batch_size, batch_size, batch_size)
        distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)
        # step 2 - get a mask for valid anchor-positive-negative triplets
        # shape: (batch_size, batch_size)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        # shape: (batch_size, batch_size, 1)
        i_equal_j = labels_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_equal_k = labels_equal.unsqueeze(1)
        # shape: (batch_size, batch_size, batch_size)
        valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))
        # step 3 - combine two masks
        mask = torch.logical_and(distinct_indices, valid_indices)
        return mask

    def forward(self, embeddings, labels, filter_easy_triplets=None):
        """computes loss value.

        Args:
            embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim)
            labels: Batch of integer labels associated with embeddings. shape: (batch_size,)
            filter_easy_triplets: Optional per-call override of the constructor setting.
                (Fix: the original accepted this argument but silently ignored it in
                favor of `self.filter_easy_triplets`; passing None keeps that behavior.)

        Returns:
            Tuple of (scalar loss value, number of valid triplets, number of triplets used).
        """
        filter_easy = self.filter_easy_triplets if filter_easy_triplets is None else filter_easy_triplets
        # step 1 - get distance matrix
        # shape: (batch_size, batch_size)
        distance_matrix = torch.cdist(embeddings, embeddings, p=2)
        # step 2 - compute loss values for all triplets by applying broadcasting to distance matrix
        # shape: (batch_size, batch_size, 1)
        anchor_positive_dists = distance_matrix.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        anchor_negative_dists = distance_matrix.unsqueeze(1)
        # get loss values for all possible n^3 triplets
        # shape: (batch_size, batch_size, batch_size)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin
        # step 3 - filter out invalid or easy triplets by setting their loss values to 0
        # shape: (batch_size, batch_size, batch_size)
        mask = self.get_triplet_mask(labels)
        valid_triplets = mask.sum()
        triplet_loss *= mask.to(self.device)
        # easy triplets have negative loss values
        triplet_loss = F.relu(triplet_loss)
        if filter_easy:
            # step 4 - compute scalar loss value by averaging positive losses
            num_positive_losses = (triplet_loss > eps).float().sum()
            # We want to factor in how many triplets were used compared to batch_size (used_triplets * 3 / batch_size)
            # The effect of this should be similar to LR decay but penalizing batches with fewer hard triplets
            percent_used_factor = min(1.0, num_positive_losses * 3 / labels.size()[0])
            triplet_loss = triplet_loss.sum() / (num_positive_losses + eps) * percent_used_factor
            return triplet_loss, valid_triplets, int(num_positive_losses)
        else:
            triplet_loss = triplet_loss.sum() / (valid_triplets + eps)
            # int(...) keeps the "used" count type consistent with the branch above
            return triplet_loss, valid_triplets, int(valid_triplets)

View File

@@ -0,0 +1,84 @@
import argparse
def get_default_args():
    """Builds the shared argparse parser with all training hyperparameters.

    Returns:
        argparse.ArgumentParser: Parser created with ``add_help=False`` (so it can be
        used as a parent parser), defining experiment, model, embedding, data,
        optimizer, scheduler, augmentation, checkpointing, tracking and batch-mining
        arguments.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--experiment_name", type=str, default="lsa_64_spoter",
                        help="Name of the experiment after which the logs and plots will be named")
    parser.add_argument("--num_classes", type=int, default=100, help="Number of classes to be recognized by the model")
    parser.add_argument("--hidden_dim", type=int, default=108,
                        help="Hidden dimension of the underlying Transformer model")
    parser.add_argument("--seed", type=int, default=379,
                        help="Seed with which to initialize all the random components of the training")
    # Embeddings
    parser.add_argument("--classification_model", action='store_true', default=False,
                        help="Select SPOTER model to train, pass only for original classification model")
    parser.add_argument("--vector_length", type=int, default=32,
                        help="Number of features used in the embedding vector")
    parser.add_argument("--epoch_iters", type=int, default=-1,
                        help="Iterations per epoch while training embeddings. Will loop through dataset once if -1")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch Size during training and validation")
    parser.add_argument("--hard_triplet_mining", type=str, default=None,
                        help="Strategy to select hard triplets, options [None, in_batch]")
    parser.add_argument("--triplet_loss_margin", type=float, default=1,
                        help="Margin used in triplet loss margin (See documentation)")
    parser.add_argument("--normalize_embeddings", action='store_true', default=False,
                        help="Normalize model output to keep vector length to one")
    parser.add_argument("--filter_easy_triplets", action='store_true', default=False,
                        help="Filter easy triplets in online in batch triplets")
    # Data
    parser.add_argument("--dataset_name", type=str, default="", help="Dataset name")
    parser.add_argument("--dataset_project", type=str, default="Sign Language Recognition", help="Dataset project name")
    parser.add_argument("--training_set_path", type=str, default="",
                        help="Path to the training dataset CSV file (relative to root dataset)")
    parser.add_argument("--validation_set_path", type=str, default="", help="Path to the validation dataset CSV file")
    parser.add_argument("--dataset_loader", type=str, default="local",
                        help="Dataset loader to use, options: [clearml, local]")
    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=1300, help="Number of epochs to train the model for")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the model training")
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout used in transformer layer")
    parser.add_argument("--augmentations_prob", type=float, default=0.5, help="How often to use data augmentation")
    # Checkpointing
    parser.add_argument("--save_checkpoints_every", type=int, default=-1,
                        help="Determines every how many epochs the weight checkpoints are saved. If -1 only best model \
                             after final epoch")
    # Optimizer
    parser.add_argument("--optimizer", type=str, default="SGD",
                        help="Optimizer used during training, options: [SGD, ADAM]")
    # Tracker
    parser.add_argument("--tracker", type=str, default="none",
                        help="Experiment tracker to use, options: [clearml, none]")
    # Scheduler
    parser.add_argument("--scheduler_factor", type=float, default=0,
                        help="Factor for the ReduceLROnPlateau scheduler")
    parser.add_argument("--scheduler_patience", type=int, default=10,
                        help="Patience for the ReduceLROnPlateau scheduler")
    parser.add_argument("--scheduler_warmup", type=int, default=400,
                        help="Warmup epochs before scheduler starts")
    # Gaussian noise normalization
    parser.add_argument("--gaussian_mean", type=float, default=0, help="Mean parameter for Gaussian noise layer")
    parser.add_argument("--gaussian_std", type=float, default=0.001,
                        help="Standard deviation parameter for Gaussian noise layer")
    # Batch Sorting
    parser.add_argument("--start_mining_hard", type=int, default=None, help="On which epoch to start hard mining")
    # NOTE(review): "multipler" is a typo, but the flag name is the CLI interface —
    # renaming it would break existing scripts; keep as-is.
    parser.add_argument("--hard_mining_pre_batch_multipler", type=int, default=16,
                        help="How many batches should be computed at once")
    parser.add_argument("--hard_mining_pre_batch_mining_count", type=int, default=5,
                        help="How many times to loop through a list of computed batches")
    parser.add_argument("--hard_mining_scheduler_triplets_threshold", type=float, default=0,
                        help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the \
                             distance threshold of the batch sorter")
    return parser

71
training/train_utils.py Normal file
View File

@@ -0,0 +1,71 @@
import os
import random
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from models import embeddings_scatter_plot, embeddings_scatter_plot_splits
def train_setup(seed, experiment_name):
    """Seeds every random number generator so training runs are reproducible.

    :param int seed: Seed applied to random, numpy, PYTHONHASHSEED and torch (CPU + CUDA)
    :param str experiment_name: Name of the experiment (not used by this function)
    :return torch.Generator: A torch generator seeded with `seed`, e.g. for DataLoader
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels (may be slower).
    torch.backends.cudnn.deterministic = True
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
def _tsne_scatter_figure(tsne_results, labels, split=None):
    """Builds a plotly scatter figure from 2-D t-SNE coordinates and labels
    (optionally encoding the dataset split as marker symbol)."""
    data = {'x': tsne_results[:, 0],
            'y': tsne_results[:, 1],
            'label': labels}
    if split is None:
        return px.scatter(pd.DataFrame(data), y="y", x="x", color="label")
    data['split'] = split
    return px.scatter(pd.DataFrame(data), y="y", x="x", color="label", symbol='split')


def _log_scatter(tracker, title, fig, epoch):
    """Sends one scatter figure to the experiment tracker."""
    tracker.log_chart(
        title=title,
        series="Scatter Plot",
        iteration=epoch,
        figure=fig,
    )


def create_embedding_scatter_plots(tracker, model, train_loader, val_loader, device, id_to_label, epoch, model_name):
    """Logs t-SNE scatter plots of the model's embeddings to the experiment tracker.

    Produces three charts: training set only, validation set only, and a combined
    train+val plot where the split is encoded as the marker symbol.
    (Refactor: the dataframe/figure/log sequence was triplicated inline; it is now
    shared via _tsne_scatter_figure and _log_scatter.)
    """
    tsne_results, labels = embeddings_scatter_plot(model, train_loader, device, id_to_label, perplexity=40, n_iter=1000)
    _log_scatter(tracker, "Training Scatter Plot with Best Model: " + model_name,
                 _tsne_scatter_figure(tsne_results, labels), epoch)
    tsne_results, labels = embeddings_scatter_plot(model, val_loader, device, id_to_label, perplexity=40, n_iter=1000)
    _log_scatter(tracker, "Validation Scatter Plot with Best Model: " + model_name,
                 _tsne_scatter_figure(tsne_results, labels), epoch)
    dataloaders = {'train': train_loader,
                   'val': val_loader}
    splits = list(dataloaders.keys())
    tsne_results_splits, labels_splits = embeddings_scatter_plot_splits(model, dataloaders,
                                                                       device, id_to_label, perplexity=40, n_iter=1000)
    tsne_results = np.vstack([tsne_results_splits[split] for split in splits])
    labels = np.concatenate([labels_splits[split] for split in splits])
    split = np.concatenate([[split] * len(labels_splits[split]) for split in splits])
    _log_scatter(tracker, "Scatter Plot of train and val with Best Model: " + model_name,
                 _tsne_scatter_figure(tsne_results, labels, split), epoch)