Initial codebase (#1)
* Add project code * Logger improvements * Improvements to web demo code * added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py * Fix rotation augmentation * fixed error in docstring, and removed unnecessary replace -1 -> 0 * Readme updates * Share base notebooks * Add notebooks and unify for different datasets * requirements update * fixes * Make evaluate more deterministic * Allow training with clearml * refactor preprocessing and apply linter * Minor fixes * Minor notebook tweaks * Readme updates * Fix PR comments * Remove unneeded code * Add banner to Readme --------- Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
This commit is contained in:
287
train.py
Normal file
287
train.py
Normal file
@@ -0,0 +1,287 @@
|
||||
|
||||
from datetime import datetime
|
||||
import os
|
||||
import os.path as op
|
||||
import argparse
|
||||
import json
|
||||
from datasets.dataset_loader import LocalDatasetLoader
|
||||
from tracking.tracker import Tracker
|
||||
import torch
|
||||
import multiprocessing
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
# import matplotlib.pyplot as plt
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from pathlib import Path
|
||||
import copy
|
||||
|
||||
from datasets import CzechSLRDataset, SLREmbeddingDataset, collate_fn_triplet_padd, collate_fn_padd
|
||||
from models import SPOTER, SPOTER_EMBEDDINGS, train_epoch, evaluate, train_epoch_embedding, \
|
||||
train_epoch_embedding_online, evaluate_embedding
|
||||
from training.online_batch_mining import BatchAllTripletLoss
|
||||
from training.batching_scheduler import BatchingScheduler
|
||||
from training.gaussian_noise import GaussianNoise
|
||||
from training.train_utils import train_setup, create_embedding_scatter_plots
|
||||
from training.train_arguments import get_default_args
|
||||
from utils import get_logger
|
||||
try:
|
||||
# Needed for argparse patching in case clearml is used
|
||||
import clearml # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
# Project name handed to the experiment tracker when the script is run directly.
PROJECT_NAME = "spoter"
# Option value that selects the ClearML-backed tracker / dataset loader
# (compared against the tracker and dataset-loader names below).
CLEARML = "clearml"
|
||||
|
||||
|
||||
def is_pre_batch_sorting_enabled(args):
    """Report whether ``args.start_mining_hard`` is set to a positive value.

    Returns ``False`` when the attribute is ``None``; otherwise returns the
    result of comparing it against zero.
    """
    start_epoch = args.start_mining_hard
    if start_epoch is None:
        return False
    return start_epoch > 0
|
||||
|
||||
|
||||
def get_tracker(tracker_name, project, experiment_name):
    """Build the experiment tracker selected by ``tracker_name``.

    When ``tracker_name`` equals ``CLEARML`` the ClearML tracker is imported
    on demand and returned; any other value yields the plain ``Tracker``.
    Both are constructed with the same project / experiment names.
    """
    tracker_kwargs = {"project_name": project, "experiment_name": experiment_name}
    if tracker_name != CLEARML:
        return Tracker(**tracker_kwargs)

    # Imported lazily so the ClearML tracker module is only needed when requested.
    from tracking.clearml_tracker import ClearMLTracker
    return ClearMLTracker(**tracker_kwargs)
|
||||
|
||||
|
||||
def get_dataset_loader(loader_name):
    """Build the dataset loader selected by ``loader_name``.

    ``CLEARML`` selects the ClearML dataset loader (imported on demand);
    every other value falls back to ``LocalDatasetLoader``.
    """
    if loader_name != CLEARML:
        return LocalDatasetLoader()

    # Imported lazily so the ClearML loader module is only needed when requested.
    from datasets.clearml_dataset_loader import ClearMLDatasetLoader
    return ClearMLDatasetLoader()
|
||||
|
||||
|
||||
def build_data_loader(dataset, batch_size, shuffle, collate_fn, generator):
    """Wrap ``dataset`` in a ``DataLoader`` with the project's common settings.

    The caller chooses batch size, shuffling, collate function and random
    generator; memory pinning and the worker count are derived from the host.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "collate_fn": collate_fn,
        "generator": generator,
        # Pinned memory is requested only when CUDA is available.
        "pin_memory": torch.cuda.is_available(),
        # One worker per CPU reported by the machine.
        "num_workers": multiprocessing.cpu_count(),
    }
    return DataLoader(dataset, **loader_options)
|
||||
|
||||
|
||||
def train(args, tracker: Tracker):
    """Run one full training experiment and keep the best-scoring checkpoint.

    Depending on ``args.classification_model`` this trains either the
    embedding model (``SPOTER_EMBEDDINGS`` with a triplet loss) or the
    classifier (``SPOTER`` with cross-entropy). Per-epoch metrics are sent to
    ``tracker``; the checkpoint with the highest validation score is written
    under ``out-checkpoints/<experiment_name>/``.

    Args:
        args: Parsed command-line namespace (see ``get_default_args``).
        tracker: Tracker used for remote execution and metric logging.

    Raises:
        ValueError: If ``args.optimizer`` is not ``"SGD"``/``"ADAM"``, or if
            the embedding model is requested with an ``args.hard_triplet_mining``
            value other than ``"None"``/``"in_batch"``.
    """
    tracker.execute_remotely(queue_name="default")
    # Initialize all the random seeds
    gen = train_setup(args.seed, args.experiment_name)
    os.environ['EXPERIMENT_NAME'] = args.experiment_name
    logger = get_logger(args.experiment_name)

    # Set device to CUDA only if applicable
    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda")

    # Construct the model
    if not args.classification_model:
        slrt_model = SPOTER_EMBEDDINGS(
            features=args.vector_length,
            hidden_dim=args.hidden_dim,
            norm_emb=args.normalize_embeddings,
            dropout=args.dropout
        )
        model_type = 'embed'
        if args.hard_triplet_mining == "None":
            cel_criterion = nn.TripletMarginLoss(margin=args.triplet_loss_margin, p=2)
        elif args.hard_triplet_mining == "in_batch":
            cel_criterion = BatchAllTripletLoss(
                device=device,
                margin=args.triplet_loss_margin,
                filter_easy_triplets=bool(args.filter_easy_triplets)
            )
        else:
            # Fail fast: otherwise the criterion and training set stay unbound
            # and the run dies later with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported hard_triplet_mining value: {args.hard_triplet_mining!r} "
                "(expected 'None' or 'in_batch')"
            )
    else:
        slrt_model = SPOTER(num_classes=args.num_classes, hidden_dim=args.hidden_dim)
        model_type = 'classif'
        cel_criterion = nn.CrossEntropyLoss()
    slrt_model.to(device)

    if args.optimizer == "SGD":
        optimizer = optim.SGD(slrt_model.parameters(), lr=args.lr)
    elif args.optimizer == "ADAM":
        optimizer = optim.Adam(slrt_model.parameters(), lr=args.lr)
    else:
        # Fail fast instead of leaving ``optimizer`` unbound.
        raise ValueError(
            f"Unsupported optimizer: {args.optimizer!r} (expected 'SGD' or 'ADAM')"
        )

    if args.scheduler_factor > 0:
        mode = 'min' if args.classification_model else 'max'
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode=mode,
            factor=args.scheduler_factor,
            patience=args.scheduler_patience
        )
        # NOTE(review): in classification mode the scheduler is created but is
        # not passed to train_epoch below -- confirm whether that is intended.
    else:
        scheduler = None

    if args.hard_mining_scheduler_triplets_threshold > 0:
        batching_scheduler = BatchingScheduler(triplets_threshold=args.hard_mining_scheduler_triplets_threshold)
    else:
        batching_scheduler = None

    # Ensure that the path for checkpointing and for images both exist
    checkpoint_dir = "out-checkpoints/" + args.experiment_name + "/"
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path("out-img/").mkdir(parents=True, exist_ok=True)

    # Training set
    transform = transforms.Compose([GaussianNoise(args.gaussian_mean, args.gaussian_std)])
    dataset_loader = get_dataset_loader(args.dataset_loader)
    dataset_folder = dataset_loader.get_dataset_folder(args.dataset_project, args.dataset_name)
    training_set_path = op.join(dataset_folder, args.training_set_path)

    with open(op.join(dataset_folder, 'id_to_label.json')) as fid:
        id_to_label = json.load(fid)
    # JSON object keys are always strings; convert them back to integer ids.
    id_to_label = {int(key): value for key, value in id_to_label.items()}

    if not args.classification_model:
        batch_size = args.batch_size
        val_batch_size = args.batch_size
        if args.hard_triplet_mining == "None":
            train_set = SLREmbeddingDataset(training_set_path, triplet=True, transform=transform, augmentations=True,
                                            augmentations_prob=args.augmentations_prob)
            collate_fn_train = collate_fn_triplet_padd
        elif args.hard_triplet_mining == "in_batch":
            train_set = SLREmbeddingDataset(training_set_path, triplet=False, transform=transform, augmentations=True,
                                            augmentations_prob=args.augmentations_prob)
            collate_fn_train = collate_fn_padd
            if is_pre_batch_sorting_enabled(args):
                # Load larger "pre-batches"; the mini-batch size is recovered
                # further down by dividing by the same multiplier.
                batch_size *= args.hard_mining_pre_batch_multipler
        train_val_set = SLREmbeddingDataset(training_set_path, triplet=False)
        # Train dataloader for validation
        train_val_loader = build_data_loader(train_val_set, val_batch_size, False, collate_fn_padd, gen)
    else:
        train_set = CzechSLRDataset(training_set_path, transform=transform, augmentations=True)
        batch_size = 1
        val_batch_size = 1
        collate_fn_train = None

    train_loader = build_data_loader(train_set, batch_size, True, collate_fn_train, gen)

    # Validation set
    validation_set_path = op.join(dataset_folder, args.validation_set_path)

    if args.classification_model:
        val_set = CzechSLRDataset(validation_set_path)
        collate_fn_val = None
    else:
        val_set = SLREmbeddingDataset(validation_set_path, triplet=False)
        collate_fn_val = collate_fn_padd

    val_loader = build_data_loader(val_set, val_batch_size, False, collate_fn_val, gen)

    # MARK: TRAINING
    losses, train_accs, val_accs = [], [], []
    lr_progress = []
    top_val_acc = -999
    # Best checkpoint seen so far; stays None until a validation score beats
    # ``top_val_acc``.
    top_model_dict = None
    top_model_name = None
    # True means "nothing new to write"; flipped to False whenever a new best
    # model is recorded and not yet saved to disk.
    top_model_saved = True

    logger.info("Starting " + args.experiment_name + "...\n\n")

    if is_pre_batch_sorting_enabled(args):
        mini_batch_size = int(batch_size / args.hard_mining_pre_batch_multipler)
    else:
        mini_batch_size = None
    enable_batch_sorting = False
    pre_batch_mining_count = 1
    for epoch in range(1, args.epochs + 1):
        start_time = datetime.now()
        if not args.classification_model:
            train_kwargs = {"model": slrt_model,
                            "epoch_iters": args.epoch_iters,
                            "train_loader": train_loader,
                            "val_loader": val_loader,
                            "criterion": cel_criterion,
                            "optimizer": optimizer,
                            "device": device,
                            # The scheduler is withheld until the warm-up epoch is reached.
                            "scheduler": scheduler if epoch >= args.scheduler_warmup else None,
                            }
            if args.hard_triplet_mining == "None":
                train_loss, val_silhouette_coef = train_epoch_embedding(**train_kwargs)
            elif args.hard_triplet_mining == "in_batch":
                if epoch == args.start_mining_hard:
                    # From this epoch on, batch sorting stays enabled.
                    enable_batch_sorting = True
                    pre_batch_mining_count = args.hard_mining_pre_batch_mining_count
                train_kwargs.update(dict(enable_batch_sorting=enable_batch_sorting,
                                         mini_batch_size=mini_batch_size,
                                         pre_batch_mining_count=pre_batch_mining_count,
                                         batching_scheduler=batching_scheduler if enable_batch_sorting else None))

                train_loss, val_silhouette_coef, triplets_stats = train_epoch_embedding_online(**train_kwargs)

                tracker.log_scalar_metric("triplets", "valid_triplets", epoch, triplets_stats["valid_triplets"])
                tracker.log_scalar_metric("triplets", "used_triplets", epoch, triplets_stats["used_triplets"])
                tracker.log_scalar_metric("triplets_pct", "pct_used", epoch, triplets_stats["pct_used"])
            tracker.log_scalar_metric("train_loss", "loss", epoch, train_loss)
            losses.append(train_loss)

            # calculate acc on train dataset
            silhouette_coefficient_train = evaluate_embedding(slrt_model, train_val_loader, device)

            tracker.log_scalar_metric("silhouette_coefficient", "train", epoch, silhouette_coefficient_train)
            train_accs.append(silhouette_coefficient_train)

            val_accs.append(val_silhouette_coef)
            tracker.log_scalar_metric("silhouette_coefficient", "val", epoch, val_silhouette_coef)

        else:
            train_loss, _, _, train_acc = train_epoch(slrt_model, train_loader, cel_criterion, optimizer, device)
            tracker.log_scalar_metric("train_loss", "loss", epoch, train_loss)
            tracker.log_scalar_metric("acc", "train", epoch, train_acc)
            losses.append(train_loss)
            train_accs.append(train_acc)

            _, _, val_acc = evaluate(slrt_model, val_loader, device)
            val_accs.append(val_acc)
            tracker.log_scalar_metric("acc", "val", epoch, val_acc)

        logger.info(f"Epoch time: {datetime.now() - start_time}")
        logger.info("[" + str(epoch) + "] TRAIN loss: " + str(train_loss) + " acc: " + str(train_accs[-1]))
        logger.info("[" + str(epoch) + "] VALIDATION acc: " + str(val_accs[-1]))

        lr_progress.append(optimizer.param_groups[0]["lr"])
        tracker.log_scalar_metric("lr", "lr", epoch, lr_progress[-1])

        if val_accs[-1] > top_val_acc:
            top_val_acc = val_accs[-1]
            top_model_name = "checkpoint_" + model_type + "_" + str(epoch) + ".pth"
            top_model_dict = {
                "name": top_model_name,
                "epoch": epoch,
                "val_acc": val_accs[-1],
                "config_args": args,
                # Deep copy so later training steps do not mutate the snapshot.
                "state_dict": copy.deepcopy(slrt_model.state_dict()),
            }
            top_model_saved = False

        # On every ``save_checkpoints_every``-th epoch, write the best model
        # seen so far if it has not been written yet. Earlier checkpoint files
        # are left in place.
        if args.save_checkpoints_every > 0 and epoch % args.save_checkpoints_every == 0 and not top_model_saved:
            _save_best_checkpoint(top_model_dict, checkpoint_dir, logger)
            top_model_saved = True

    # Save the top model if it is still pending (periodic checkpoints disabled,
    # or the best epoch came after the last periodic save).
    if not top_model_saved:
        _save_best_checkpoint(top_model_dict, checkpoint_dir, logger)

    # Log scatter plots
    if not args.classification_model and args.hard_triplet_mining == "in_batch":
        if top_model_dict is None:
            # No epoch ran, or no validation score ever compared greater than
            # the initial threshold (e.g. NaN), so there are no weights to load.
            logger.info("No best model was recorded; skipping scatter plot.")
        else:
            logger.info("Generating Scatter Plot.")
            # Re-use the live model object, rewound to the best weights.
            best_model = slrt_model
            best_model.load_state_dict(top_model_dict["state_dict"])
            create_embedding_scatter_plots(tracker, best_model, train_loader, val_loader, device, id_to_label, epoch,
                                           top_model_name)
    logger.info("The experiment is finished.")


def _save_best_checkpoint(checkpoint, checkpoint_dir, logger):
    """Write ``checkpoint`` into ``checkpoint_dir`` under its ``"name"`` and log it."""
    torch.save(checkpoint, checkpoint_dir + checkpoint["name"])
    logger.info("Saved new best checkpoint: " + checkpoint["name"])
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Build the CLI on top of the project's shared default arguments; the
    # parent parser's options are reused and no extra help flag is added here.
    cli_parser = argparse.ArgumentParser(prog="", parents=[get_default_args()], add_help=False)
    cli_args = cli_parser.parse_args()

    # Pick the tracker requested on the command line, run training, then
    # close the tracking run.
    run_tracker = get_tracker(
        tracker_name=cli_args.tracker,
        project=PROJECT_NAME,
        experiment_name=cli_args.experiment_name,
    )
    train(cli_args, run_tracker)
    run_tracker.finish_run()
|
||||
Reference in New Issue
Block a user