Files
spoterembedding/datasets/czech_slr_dataset.py
Mathias Claassen 81bbf66aab Initial codebase (#1)
* Add project code

* Logger improvements

* Improvements to web demo code

* added create_wlasl_landmarks_dataset.py and extract_mediapipe_landmarks.py

* Fix rotation augmentation

* fixed error in docstring, and removed unnecessary replace -1 -> 0

* Readme updates

* Share base notebooks

* Add notebooks and unify for different datasets

* requirements update

* fixes

* Make evaluate more deterministic

* Allow training with clearml

* refactor preprocessing and apply linter

* Minor fixes

* Minor notebook tweaks

* Readme updates

* Fix PR comments

* Remove unneeded code

* Add banner to Readme

---------

Co-authored-by: Gabriel Lema <gabriel.lema@xmartlabs.com>
2023-03-03 10:07:54 -03:00

73 lines
2.4 KiB
Python

import torch
import numpy as np
import torch.utils.data as torch_data
from datasets.datasets_utils import load_dataset, tensor_to_dictionary, dictionary_to_tensor, \
random_augmentation
from normalization.body_normalization import normalize_single_dict as normalize_single_body_dict
from normalization.hand_normalization import normalize_single_dict as normalize_single_hand_dict
class CzechSLRDataset(torch_data.Dataset):
    """Advanced object representation of the HPOES dataset for loading hand joints landmarks utilizing the Torch's
    built-in Dataset properties"""

    # Per-item landmark arrays and their class labels, populated in __init__.
    # (String annotations keep this version-safe: `list[np.ndarray]` as a bare
    # class-level annotation would be evaluated eagerly on Python < 3.9.)
    data: "list[np.ndarray]"
    labels: "list[np.ndarray]"

    def __init__(self, dataset_filename: str, num_labels=5, transform=None, augmentations=False,
                 augmentations_prob=0.5, normalize=True):
        """
        Initiates the HPOESDataset with the pre-loaded data from the h5 file.

        :param dataset_filename: Path to the h5 file
        :param num_labels: Number of distinct classes in the dataset (default: 5)
        :param transform: Any data transformation to be applied (default: None)
        :param augmentations: Whether to apply random augmentations per item (default: False)
        :param augmentations_prob: Probability of applying an augmentation when enabled (default: 0.5)
        :param normalize: Whether to normalize body and hand landmarks (default: True)
        """
        loaded_data = load_dataset(dataset_filename)
        data, labels = loaded_data[0], loaded_data[1]

        self.data = data
        self.labels = labels
        # Also expose labels under the torchvision-style `targets` name.
        self.targets = list(labels)
        self.num_labels = num_labels
        self.transform = transform

        self.augmentations = augmentations
        self.augmentations_prob = augmentations_prob
        self.normalize = normalize

    def __getitem__(self, idx):
        """
        Allocates, potentially transforms and returns the item at the desired index.

        :param idx: Index of the item
        :return: Tuple containing both the depth map and the label
        """
        depth_map = torch.from_numpy(np.copy(self.data[idx]))
        label = torch.Tensor([self.labels[idx]])

        depth_map = tensor_to_dictionary(depth_map)

        # Apply potential augmentations
        depth_map = random_augmentation(self.augmentations, self.augmentations_prob, depth_map)

        if self.normalize:
            depth_map = normalize_single_body_dict(depth_map)
            depth_map = normalize_single_hand_dict(depth_map)

        depth_map = dictionary_to_tensor(depth_map)

        # Move the landmark position interval to improve performance
        depth_map = depth_map - 0.5

        if self.transform:
            depth_map = self.transform(depth_map)

        return depth_map, label

    def __len__(self):
        """Return the number of items in the dataset."""
        return len(self.labels)