Initial codebase (#1)
* Add project code * Logger improvements * Improvements to web demo code * added create_wlasl_landmarks_dataset.py and xtract_mediapipe_landmarks.py * Fix rotation augmentation * fixed error in docstring, and removed unnecessary replace -1 -> 0 * Readme updates * Share base notebooks * Add notebooks and unify for different datasets * requirements update * fixes * Make evaluate more deterministic * Allow training with clearml * refactor preprocessing and apply linter * Minor fixes * Minor notebook tweaks * Readme updates * Fix PR comments * Remove unneeded code * Add banner to Readme --------- Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
This commit is contained in:
72
datasets/czech_slr_dataset.py
Normal file
72
datasets/czech_slr_dataset.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.utils.data as torch_data
|
||||
|
||||
from datasets.datasets_utils import load_dataset, tensor_to_dictionary, dictionary_to_tensor, \
|
||||
random_augmentation
|
||||
from normalization.body_normalization import normalize_single_dict as normalize_single_body_dict
|
||||
from normalization.hand_normalization import normalize_single_dict as normalize_single_hand_dict
|
||||
|
||||
|
||||
class CzechSLRDataset(torch_data.Dataset):
|
||||
"""Advanced object representation of the HPOES dataset for loading hand joints landmarks utilizing the Torch's
|
||||
built-in Dataset properties"""
|
||||
|
||||
data: [np.ndarray]
|
||||
labels: [np.ndarray]
|
||||
|
||||
def __init__(self, dataset_filename: str, num_labels=5, transform=None, augmentations=False,
|
||||
augmentations_prob=0.5, normalize=True):
|
||||
"""
|
||||
Initiates the HPOESDataset with the pre-loaded data from the h5 file.
|
||||
|
||||
:param dataset_filename: Path to the h5 file
|
||||
:param transform: Any data transformation to be applied (default: None)
|
||||
"""
|
||||
|
||||
loaded_data = load_dataset(dataset_filename)
|
||||
data, labels = loaded_data[0], loaded_data[1]
|
||||
|
||||
self.data = data
|
||||
self.labels = labels
|
||||
self.targets = list(labels)
|
||||
self.num_labels = num_labels
|
||||
self.transform = transform
|
||||
|
||||
self.augmentations = augmentations
|
||||
self.augmentations_prob = augmentations_prob
|
||||
self.normalize = normalize
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Allocates, potentially transforms and returns the item at the desired index.
|
||||
|
||||
:param idx: Index of the item
|
||||
:return: Tuple containing both the depth map and the label
|
||||
"""
|
||||
|
||||
depth_map = torch.from_numpy(np.copy(self.data[idx]))
|
||||
# label = torch.Tensor([self.labels[idx] - 1])
|
||||
label = torch.Tensor([self.labels[idx]])
|
||||
|
||||
depth_map = tensor_to_dictionary(depth_map)
|
||||
|
||||
# Apply potential augmentations
|
||||
depth_map = random_augmentation(self.augmentations, self.augmentations_prob, depth_map)
|
||||
|
||||
if self.normalize:
|
||||
depth_map = normalize_single_body_dict(depth_map)
|
||||
depth_map = normalize_single_hand_dict(depth_map)
|
||||
|
||||
depth_map = dictionary_to_tensor(depth_map)
|
||||
|
||||
# Move the landmark position interval to improve performance
|
||||
depth_map = depth_map - 0.5
|
||||
|
||||
if self.transform:
|
||||
depth_map = self.transform(depth_map)
|
||||
|
||||
return depth_map, label
|
||||
|
||||
def __len__(self):
|
||||
return len(self.labels)
|
||||
Reference in New Issue
Block a user