Initial codebase (#1)
* Add project code * Logger improvements * Improvements to web demo code * added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py * Fix rotation augmentation * fixed error in docstring, and removed unnecessary replace -1 -> 0 * Readme updates * Share base notebooks * Add notebooks and unify for different datasets * requirements update * fixes * Make evaluate more deterministic * Allow training with clearml * refactor preprocessing and apply linter * Minor fixes * Minor notebook tweaks * Readme updates * Fix PR comments * Remove unneeded code * Add banner to Readme --------- Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
This commit is contained in:
0
training/__init__.py
Normal file
0
training/__init__.py
Normal file
215
training/batch_sorter.py
Normal file
215
training/batch_sorter.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
from .batching_scheduler import BatchingScheduler
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger("BatchGrouper")
|
||||
|
||||
|
||||
class BatchGrouper:
    """Clusters item indices into fixed-size mini-batch groups using pair distances.

    Will cluster all `total_items` into `max_groups` clusters based on distances in
    `sorted_dists`. Each group has `mini_batch_size` elements and these elements are just
    integers in range 0...total_items.

    Distances between these items are expected to be scaled to 0...1 in a way that
    distances for two items in the same class are higher if closer to 1, while distances
    between elements of different classes are higher if closer to 0 (i.e. higher means a
    harder pair, see `get_scaled_distances`).

    The logic is picking the highest value distance and assigning both items to the same
    cluster/group if possible. This might include merging 2 clusters.
    There are a few thresholds to limit the computational cost. If the scaled distance
    between a pair is below `dist_threshold`, or more than `assign_threshold` percent of
    items have been assigned to the groups, we stop and assign the remaining items to the
    groups that have space left.
    """
    # Class-level defaults (always shadowed by instance attributes set in __init__,
    # kept for backward compatibility with any class-level access).
    # Counters
    next_group = 0
    items_assigned = 0

    # Thresholds
    dist_threshold = 0.5
    assign_threshold = 0.80

    # Logger held on the class so the grouping logic is self-contained.
    _log = logging.getLogger("BatchGrouper")

    def __init__(self, sorted_dists, total_items, mini_batch_size=32, dist_threshold=0.5, assign_threshold=0.8) -> None:
        """
        :param sorted_dists: (3, n_pairs) array of (item_a, item_b, dist) columns, sorted by dist descending
        :param int total_items: number of items to cluster (item ids are 0...total_items-1)
        :param int mini_batch_size: target size of every group
        :param float dist_threshold: stop clustering once pair distances drop below this value
        :param float assign_threshold: stop clustering once this fraction of items is assigned
        """
        self.sorted_dists = sorted_dists
        self.total_items = total_items
        self.mini_batch_size = mini_batch_size
        self.max_groups = int(total_items / mini_batch_size)
        self.groups = {}          # group id -> list of item ids
        self.item_to_group = {}   # item id -> group id
        self.items_assigned = 0
        self.next_group = 0
        self.dist_threshold = dist_threshold
        self.assign_threshold = assign_threshold

    def cluster_items(self):
        """Main function of this class. Does the clustering explained in class docstring.

        :return list: all item ids, ordered so that every consecutive `mini_batch_size`
            slice corresponds to one group
        """
        for i in range(self.sorted_dists.shape[-1]):
            a, b, dist = self.sorted_dists[:, i]
            a, b = int(a), int(b)
            # Stop early once the remaining pairs are too easy or enough items are placed.
            if dist < self.dist_threshold or self.items_assigned > self.total_items * self.assign_threshold:
                self._log.info("Breaking with dist: %s, and %s items assigned", dist, self.items_assigned)
                break
            if a not in self.item_to_group and b not in self.item_to_group:
                g = self.create_or_get_group()
                self.assign_group(a, g)
                self.assign_group(b, g)
            elif a not in self.item_to_group:
                if not self.group_is_full(self.item_to_group[b]):
                    self.assign_group(a, self.item_to_group[b])
            elif b not in self.item_to_group:
                if not self.group_is_full(self.item_to_group[a]):
                    self.assign_group(b, self.item_to_group[a])
            else:
                # Both already placed: try to bring them together by merging their groups.
                grp_a = self.item_to_group[a]
                grp_b = self.item_to_group[b]
                self.merge_groups(grp_a, grp_b)
        self.assign_remaining_items()
        return list(np.concatenate(list(self.groups.values())).flat)

    def assign_group(self, item, group):
        """Assigns `item` to group `group` and updates the bookkeeping counters."""
        self.item_to_group[item] = group
        self.groups[group].append(item)
        self.items_assigned += 1

    def create_or_get_group(self):
        """Creates a new group if current group count is less than max_groups.
        Otherwise returns the first group with at least two free slots (callers may
        assign a pair of items to the returned group).

        :return int: The group id
        """
        if self.next_group < self.max_groups:
            group = self.next_group
            self.groups[group] = []
            self.next_group += 1
        else:
            for i in range(self.next_group):
                if len(self.groups[i]) <= self.mini_batch_size - 2:
                    group = i
                    break  # out of the for loop
            else:
                # Bug fix: previously `group` was unbound here (UnboundLocalError) when
                # no group had two free slots left. Fall back to the emptiest group so
                # clustering can always proceed.
                group = min(self.groups, key=lambda grp: len(self.groups[grp]))
        return group

    def group_is_full(self, group):
        """True when `group` has no free slots left."""
        # `>=` (not `==`) so a group that overflowed via the create_or_get_group
        # fallback is still reported as full.
        return len(self.groups[group]) >= self.mini_batch_size

    def can_merge_groups(self, grp_a, grp_b):
        """True when `grp_a` and `grp_b` are distinct and their union fits in one group."""
        return grp_a != grp_b and (len(self.groups[grp_a]) + len(self.groups[grp_b]) < self.mini_batch_size)

    def merge_groups(self, grp_a, grp_b):
        """Will merge two groups together, if possible. Otherwise does nothing.

        The higher-numbered group is emptied into the lower-numbered one and then
        removed via `replace_group` so group ids stay contiguous.
        """
        if grp_a > grp_b:
            grp_a, grp_b = grp_b, grp_a
        if self.can_merge_groups(grp_a, grp_b):
            self._log.debug("MERGE %s with %s: %s %s", grp_a, grp_b, len(self.groups[grp_a]), len(self.groups[grp_b]))
            for b in self.groups[grp_b]:
                self.item_to_group[b] = grp_a
            self.groups[grp_a].extend(self.groups[grp_b])
            self.groups[grp_b] = []
            self.replace_group(grp_b)

    def replace_group(self, group):
        """Replace a group with the last one in the list

        :param int group: Group to replace (assumed empty at this point)
        """
        grp_to_change = self.next_group - 1
        if grp_to_change != group:
            # Move the last group's items into the freed slot and retarget their mapping.
            for item in self.groups[grp_to_change]:
                self.item_to_group[item] = group
            self.groups[group] = self.groups[grp_to_change]
        del self.groups[grp_to_change]
        self.next_group -= 1

    def assign_remaining_items(self):
        """Assign remaining items into groups, filling each group in id order."""
        grp_pointer = 0
        i = 0
        self._log.info("Assigning rest of items: %s of %s", self.items_assigned, self.total_items)
        while i < self.total_items:
            if i not in self.item_to_group:
                if grp_pointer not in self.groups:
                    # This would happen if a group is still empty at this stage
                    assert grp_pointer < self.max_groups
                    new_group = self.create_or_get_group()
                    assert new_group == grp_pointer
                if len(self.groups[grp_pointer]) < self.mini_batch_size:
                    self.assign_group(i, grp_pointer)
                    i += 1
                else:
                    grp_pointer += 1
            else:
                i += 1
|
||||
|
||||
|
||||
def get_dist_tuple_list(dist_matrix):
    """Flatten the strict lower triangle of `dist_matrix` into a (3, n_pairs) tensor.

    Row 0 holds the row indices, row 1 the column indices and row 2 the distances,
    so every column is one (item_a, item_b, dist) tuple.
    """
    n = dist_matrix.size(0)
    pair_idx = torch.tril_indices(n, n, offset=-1)
    pair_dists = dist_matrix[pair_idx[0], pair_idx[1]].cpu().unsqueeze(0)
    return torch.cat([pair_idx, pair_dists], dim=0)
|
||||
|
||||
|
||||
def get_scaled_distances(embeddings, labels, device, same_label_factor=1):
    """Returns distance matrix between all embeddings scaled to the 0-1 range where 0 is good and 1 is bad.

    Small distances for embeddings of the same class end up close to 0, while small
    distances for embeddings of different classes end up close to 1 (hard pairs score high).

    :param _type_ embeddings: Embeddings of batch items
    :param _type_ labels: Labels associated to the embeddings
    :param _type_ device: Device to run on (cuda or cpu)
    :param int same_label_factor: Multiplies the weight of same-class distances allowing to give more or less
        importance to these compared to distinct-class distances, defaults to 1 (which means equal weight)
    :return torch.Tensor: Scaled distance matrix
    """
    # Pairwise Euclidean distances, shape: (batch_size, batch_size)
    dists = torch.cdist(embeddings, embeddings, p=2)

    labels_on_device = labels.to(device)
    same_label = (labels_on_device.unsqueeze(0) == labels_on_device.unsqueeze(1)).squeeze()
    diff_label = ~same_label

    # Split distances into same-class and different-class parts.
    same_dists = dists * same_label
    diff_dists = dists * diff_label

    # Rescale each part by its own maximum so both live in 0...1.
    # NOTE(review): assumes the batch contains at least one same-label and one
    # different-label pair; otherwise a max() here is 0 and this divides by zero —
    # confirm upstream batching guarantees this.
    scaled_same = same_dists / same_dists.max() * same_label_factor
    # For different-label pairs, invert so that *small* raw distances map near 1 (hard negatives).
    scaled_diff = diff_label.to(dists.dtype) - diff_dists / diff_dists.max()
    return scaled_same + scaled_diff
|
||||
|
||||
|
||||
def sort_batches(inputs, labels, masks, embeddings, device, mini_batch_size=32,
                 scheduler: Optional[BatchingScheduler] = None):
    """Reorder a large batch so consecutive `mini_batch_size` chunks contain hard pairs.

    :param inputs: batch inputs, first dimension is the item dimension
    :param labels: batch labels, shape (batch_size, ...)
    :param masks: batch masks, first dimension is the item dimension
    :param embeddings: embeddings used to measure pair hardness
    :param device: device to compute the distances on
    :param int mini_batch_size: group size handed to BatchGrouper
    :param scheduler: optional BatchingScheduler providing decayed thresholds
    :return: (inputs, labels, masks) reordered by the computed grouping
    """
    start = datetime.now()

    # Pull the (possibly decayed) thresholds from the scheduler when one is provided.
    if scheduler:
        same_label_factor = scheduler.get_scaling_same_label_factor()
        dist_threshold = scheduler.get_dist_threshold()
    else:
        same_label_factor = 1
        dist_threshold = 0.5

    scaled_dist = get_scaled_distances(embeddings, labels, device, same_label_factor)
    # (row, column, dist) columns for every pair
    pair_table = get_dist_tuple_list(scaled_dist).cpu().detach().numpy()

    # Order the pairs by distance, hardest (largest) first.
    order = pair_table[-1, :].argsort()[::-1]
    sorted_dists = pair_table[:, order]

    # Walk the pair list assigning both items of each pair to the same group.
    grouper = BatchGrouper(sorted_dists, total_items=labels.size()[0],
                           mini_batch_size=mini_batch_size, dist_threshold=dist_threshold)
    indices = torch.tensor(grouper.cluster_items()).type(torch.IntTensor)

    reordered = tuple(torch.index_select(tensor, dim=0, index=indices)
                      for tensor in (inputs, labels, masks))

    logger.info(f"Batch sorting took: {datetime.now() - start}")
    return reordered
|
||||
62
training/batching_scheduler.py
Normal file
62
training/batching_scheduler.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from collections import deque
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BatchingScheduler:
    """Scheduler for the batch-sorting (hard mining) algorithm.

    Tracks how many triplets the last few steps actually used and, once that rolling
    average drops below `triplets_threshold` (respecting a `cooldown` between updates),
    decays both the distance threshold and the same-label scaling factor consumed by
    the batch sorter.
    """

    def __init__(self, decay_factor=0.8, min_threshold=0.2, triplets_threshold=10, cooldown=10) -> None:
        # internal state
        self._step_count = 0
        self._dist_threshold = 0.5
        self._last_used_triplets = deque([], 5)  # rolling window of the last 5 counts
        self._scaling_same_label_factor = 1
        self._last_update_step = -10  # negative so the very first update is not blocked by cooldown

        # Parameters
        self.decay_factor = decay_factor
        self.min_threshold = min_threshold
        self.triplets_threshold = triplets_threshold
        self.cooldown = cooldown

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`."""
        return dict(self.__dict__)

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self, used_triplets):
        """Records `used_triplets` for this step and decays thresholds when mining stalls."""
        self._step_count += 1
        self._last_used_triplets.append(used_triplets)
        avg_used = np.mean(self._last_used_triplets)
        if avg_used >= self.triplets_threshold:
            return
        if self._last_update_step + self.cooldown > self._step_count:
            return  # still within the cooldown window of the previous update
        if self._dist_threshold > self.min_threshold:
            print(f"Updating dist_threshold at {self._step_count} ({avg_used})")
            self.update_dist_threshold()
        if self._scaling_same_label_factor > 0.6:
            print(f"Updating scale factor at {self._step_count} ({avg_used})")
            self.update_scale_factor()
        self._last_update_step = self._step_count

    def update_scale_factor(self):
        """Decays the same-label scaling factor by 10%, never below 0.6."""
        self._scaling_same_label_factor = max(self._scaling_same_label_factor * 0.9, 0.6)
        print(f"Updating scaling factor to {self._scaling_same_label_factor}")

    def update_dist_threshold(self):
        """Decays the distance threshold by `decay_factor`, never below `min_threshold`."""
        self._dist_threshold = max(self.min_threshold, self._dist_threshold * self.decay_factor)
        print(f"Updated dist_threshold to {self._dist_threshold}")

    def get_dist_threshold(self) -> float:
        """Current distance threshold for the batch sorter."""
        return self._dist_threshold

    def get_scaling_same_label_factor(self) -> float:
        """Current same-label scaling factor for the batch sorter."""
        return self._scaling_same_label_factor
|
||||
18
training/gaussian_noise.py
Normal file
18
training/gaussian_noise.py
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class GaussianNoise(object):
    """Callable transform that adds Gaussian noise drawn from N(mean, std^2) to a tensor."""

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # Fresh noise with the tensor's shape is drawn on every call.
        noise = torch.randn(tensor.size())
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
||||
|
||||
|
||||
if __name__ == "__main__":
    # Placeholder entry point: this module currently has nothing to run standalone.
    pass
|
||||
105
training/online_batch_mining.py
Normal file
105
training/online_batch_mining.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# Small constant used below to guard divisions and to test loss values against zero.
eps = 1e-8  # an arbitrary small value to be used for numerical stability tricks

# Adapted from https://qdrant.tech/articles/triplet-loss/
|
||||
|
||||
|
||||
class BatchAllTripletLoss(nn.Module):
    """Uses all valid triplets to compute Triplet loss ("batch-all" mining).

    Args:
        device: Device the triplet mask is moved to when weighting the loss
        margin: Margin value in the Triplet Loss equation
        filter_easy_triplets: When True, average only over triplets with a positive
            loss value; when False, average over all valid triplets.
    """

    def __init__(self, device, margin=1., filter_easy_triplets=True):
        super().__init__()
        self.margin = margin
        self.device = device
        self.filter_easy_triplets = filter_easy_triplets

    def get_triplet_mask(self, labels):
        """compute a mask for valid triplets

        Args:
            labels: Batch of integer labels. shape: (batch_size,)
        Returns:
            Mask tensor to indicate which triplets are actually valid. Shape: (batch_size, batch_size, batch_size)
            A triplet (i, j, k) is valid if:
            `labels[i] == labels[j] and labels[i] != labels[k]`
            and `i`, `j`, `k` are different.
        """
        # step 1 - get a mask for distinct indices

        # shape: (batch_size, batch_size)
        indices_equal = torch.eye(labels.size()[0], dtype=torch.bool, device=labels.device)
        indices_not_equal = torch.logical_not(indices_equal)
        # shape: (batch_size, batch_size, 1)
        i_not_equal_j = indices_not_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_not_equal_k = indices_not_equal.unsqueeze(1)
        # shape: (1, batch_size, batch_size)
        j_not_equal_k = indices_not_equal.unsqueeze(0)
        # Shape: (batch_size, batch_size, batch_size)
        distinct_indices = torch.logical_and(torch.logical_and(i_not_equal_j, i_not_equal_k), j_not_equal_k)

        # step 2 - get a mask for valid anchor-positive-negative triplets

        # shape: (batch_size, batch_size)
        labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
        # shape: (batch_size, batch_size, 1)
        i_equal_j = labels_equal.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        i_equal_k = labels_equal.unsqueeze(1)
        # shape: (batch_size, batch_size, batch_size)
        valid_indices = torch.logical_and(i_equal_j, torch.logical_not(i_equal_k))

        # step 3 - combine two masks
        mask = torch.logical_and(distinct_indices, valid_indices)

        return mask

    def forward(self, embeddings, labels, filter_easy_triplets=None):
        """computes loss value.

        Args:
            embeddings: Batch of embeddings, e.g., output of the encoder. shape: (batch_size, embedding_dim)
            labels: Batch of integer labels associated with embeddings. shape: (batch_size,)
            filter_easy_triplets: Per-call override of the constructor flag. Bug fix: this
                parameter used to be silently ignored; it now takes precedence when given,
                while `None` (the default) falls back to `self.filter_easy_triplets`.
        Returns:
            Tuple of (scalar loss value, number of valid triplets, number of triplets used).
        """
        filter_easy = self.filter_easy_triplets if filter_easy_triplets is None else filter_easy_triplets

        # step 1 - get distance matrix
        # shape: (batch_size, batch_size)
        distance_matrix = torch.cdist(embeddings, embeddings, p=2)

        # step 2 - compute loss values for all triplets by applying broadcasting to distance matrix

        # shape: (batch_size, batch_size, 1)
        anchor_positive_dists = distance_matrix.unsqueeze(2)
        # shape: (batch_size, 1, batch_size)
        anchor_negative_dists = distance_matrix.unsqueeze(1)
        # get loss values for all possible n^3 triplets
        # shape: (batch_size, batch_size, batch_size)
        triplet_loss = anchor_positive_dists - anchor_negative_dists + self.margin

        # step 3 - filter out invalid or easy triplets by setting their loss values to 0

        # shape: (batch_size, batch_size, batch_size)
        mask = self.get_triplet_mask(labels)
        valid_triplets = mask.sum()
        triplet_loss *= mask.to(self.device)
        # easy triplets have negative loss values
        triplet_loss = F.relu(triplet_loss)

        if filter_easy:
            # step 4 - compute scalar loss value by averaging positive losses
            num_positive_losses = (triplet_loss > eps).float().sum()
            # We want to factor in how many triplets were used compared to batch_size (used_triplets * 3 / batch_size)
            # The effect of this should be similar to LR decay but penalizing batches with fewer hard triplets
            percent_used_factor = min(1.0, num_positive_losses * 3 / labels.size()[0])

            triplet_loss = triplet_loss.sum() / (num_positive_losses + eps) * percent_used_factor
            return triplet_loss, valid_triplets, int(num_positive_losses)
        else:
            triplet_loss = triplet_loss.sum() / (valid_triplets + eps)
            return triplet_loss, valid_triplets, valid_triplets
|
||||
84
training/train_arguments.py
Normal file
84
training/train_arguments.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import argparse
|
||||
|
||||
|
||||
def get_default_args():
    """Builds the base argument parser shared by the training entry points.

    Returns an `argparse.ArgumentParser` created with `add_help=False` (so it can be
    composed as a parent parser) and pre-populated with all common training options.

    Bug fix: two help strings previously used backslash line-continuation *inside* the
    string literal, which embedded long runs of indentation whitespace into the
    displayed help text; they now use implicit string concatenation.
    """
    parser = argparse.ArgumentParser(add_help=False)

    parser.add_argument("--experiment_name", type=str, default="lsa_64_spoter",
                        help="Name of the experiment after which the logs and plots will be named")
    parser.add_argument("--num_classes", type=int, default=100, help="Number of classes to be recognized by the model")
    parser.add_argument("--hidden_dim", type=int, default=108,
                        help="Hidden dimension of the underlying Transformer model")
    parser.add_argument("--seed", type=int, default=379,
                        help="Seed with which to initialize all the random components of the training")

    # Embeddings
    parser.add_argument("--classification_model", action='store_true', default=False,
                        help="Select SPOTER model to train, pass only for original classification model")
    parser.add_argument("--vector_length", type=int, default=32,
                        help="Number of features used in the embedding vector")
    parser.add_argument("--epoch_iters", type=int, default=-1,
                        help="Iterations per epoch while training embeddings. Will loop through dataset once if -1")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch Size during training and validation")
    parser.add_argument("--hard_triplet_mining", type=str, default=None,
                        help="Strategy to select hard triplets, options [None, in_batch]")
    parser.add_argument("--triplet_loss_margin", type=float, default=1,
                        help="Margin used in triplet loss margin (See documentation)")
    parser.add_argument("--normalize_embeddings", action='store_true', default=False,
                        help="Normalize model output to keep vector length to one")
    parser.add_argument("--filter_easy_triplets", action='store_true', default=False,
                        help="Filter easy triplets in online in batch triplets")

    # Data
    parser.add_argument("--dataset_name", type=str, default="", help="Dataset name")
    parser.add_argument("--dataset_project", type=str, default="Sign Language Recognition", help="Dataset project name")
    parser.add_argument("--training_set_path", type=str, default="",
                        help="Path to the training dataset CSV file (relative to root dataset)")
    parser.add_argument("--validation_set_path", type=str, default="", help="Path to the validation dataset CSV file")
    parser.add_argument("--dataset_loader", type=str, default="local",
                        help="Dataset loader to use, options: [clearml, local]")

    # Training hyperparameters
    parser.add_argument("--epochs", type=int, default=1300, help="Number of epochs to train the model for")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate for the model training")
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout used in transformer layer")
    parser.add_argument("--augmentations_prob", type=float, default=0.5, help="How often to use data augmentation")

    # Checkpointing
    parser.add_argument("--save_checkpoints_every", type=int, default=-1,
                        help="Determines every how many epochs the weight checkpoints are saved. If -1 only best model "
                             "after final epoch")

    # Optimizer
    parser.add_argument("--optimizer", type=str, default="SGD",
                        help="Optimizer used during training, options: [SGD, ADAM]")

    # Tracker
    parser.add_argument("--tracker", type=str, default="none",
                        help="Experiment tracker to use, options: [clearml, none]")

    # Scheduler
    parser.add_argument("--scheduler_factor", type=float, default=0,
                        help="Factor for the ReduceLROnPlateau scheduler")
    parser.add_argument("--scheduler_patience", type=int, default=10,
                        help="Patience for the ReduceLROnPlateau scheduler")
    parser.add_argument("--scheduler_warmup", type=int, default=400,
                        help="Warmup epochs before scheduler starts")

    # Gaussian noise normalization
    parser.add_argument("--gaussian_mean", type=float, default=0, help="Mean parameter for Gaussian noise layer")
    parser.add_argument("--gaussian_std", type=float, default=0.001,
                        help="Standard deviation parameter for Gaussian noise layer")

    # Batch Sorting
    parser.add_argument("--start_mining_hard", type=int, default=None, help="On which epoch to start hard mining")
    # NOTE(review): "multipler" is a typo for "multiplier", but the flag name is part of
    # the CLI interface so it is kept as-is for backward compatibility.
    parser.add_argument("--hard_mining_pre_batch_multipler", type=int, default=16,
                        help="How many batches should be computed at once")
    parser.add_argument("--hard_mining_pre_batch_mining_count", type=int, default=5,
                        help="How many times to loop through a list of computed batches")
    parser.add_argument("--hard_mining_scheduler_triplets_threshold", type=float, default=0,
                        help="Enables batching grouping scheduler if > 0. Defines threshold for when to decay the "
                             "distance threshold of the batch sorter")

    return parser
|
||||
71
training/train_utils.py
Normal file
71
training/train_utils.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import os
|
||||
import random
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import torch
|
||||
|
||||
from models import embeddings_scatter_plot, embeddings_scatter_plot_splits
|
||||
|
||||
|
||||
def train_setup(seed, experiment_name):
    """Seed every RNG used during training and return a seeded torch.Generator.

    :param int seed: seed applied to random, numpy, torch (CPU + CUDA) and PYTHONHASHSEED
    :param str experiment_name: unused here  # NOTE(review): kept for interface compatibility
    :return torch.Generator: generator seeded with `seed` (e.g. for DataLoader workers)
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Seed torch on CPU and on every CUDA device, and force deterministic cuDNN kernels.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
|
||||
|
||||
|
||||
def _log_scatter_chart(tracker, df, title, epoch, symbol=None):
    """Build a plotly-express scatter from `df` and log it to the experiment tracker."""
    fig = px.scatter(df, y="y", x="x", color="label", symbol=symbol)
    tracker.log_chart(
        title=title,
        series="Scatter Plot",
        iteration=epoch,
        figure=fig,
    )


def create_embedding_scatter_plots(tracker, model, train_loader, val_loader, device, id_to_label, epoch, model_name):
    """Logs t-SNE scatter plots of the model's embeddings to the experiment tracker.

    Three charts are produced: one for the training split, one for the validation split,
    and one combined chart where the split is encoded as the marker symbol.
    (Refactored: the three duplicated DataFrame/figure/log_chart sequences now share the
    `_log_scatter_chart` helper.)

    :param tracker: experiment tracker exposing `log_chart(title, series, iteration, figure)`
    :param model: embedding model to evaluate
    :param train_loader: dataloader for the training split
    :param val_loader: dataloader for the validation split
    :param device: device to run inference on
    :param id_to_label: mapping from class id to human-readable label
    :param int epoch: iteration number the charts are logged under
    :param str model_name: appended to the chart titles
    """
    # Per-split charts.
    for split_title, loader in (("Training", train_loader), ("Validation", val_loader)):
        tsne_results, labels = embeddings_scatter_plot(model, loader, device, id_to_label, perplexity=40, n_iter=1000)
        df = pd.DataFrame({'x': tsne_results[:, 0],
                           'y': tsne_results[:, 1],
                           'label': labels})
        _log_scatter_chart(tracker, df, split_title + " Scatter Plot with Best Model: " + model_name, epoch)

    # Combined chart with both splits, distinguished by marker symbol.
    dataloaders = {'train': train_loader,
                   'val': val_loader}
    splits = list(dataloaders.keys())
    tsne_results_splits, labels_splits = embeddings_scatter_plot_splits(model, dataloaders,
                                                                        device, id_to_label, perplexity=40,
                                                                        n_iter=1000)
    tsne_results = np.vstack([tsne_results_splits[split] for split in splits])
    labels = np.concatenate([labels_splits[split] for split in splits])
    split = np.concatenate([[split] * len(labels_splits[split]) for split in splits])
    df = pd.DataFrame({'x': tsne_results[:, 0],
                       'y': tsne_results[:, 1],
                       'label': labels,
                       'split': split})
    _log_scatter_chart(tracker, df, "Scatter Plot of train and val with Best Model: " + model_name, epoch,
                       symbol='split')
|
||||
Reference in New Issue
Block a user