Files
spoterembedding/datasets/czech_slr_dataset.py
2023-04-07 09:44:12 +00:00

74 lines
2.4 KiB
Python

import torch
import numpy as np
import torch.utils.data as torch_data
from datasets.datasets_utils import load_dataset, tensor_to_dictionary, dictionary_to_tensor, \
random_augmentation
from normalization.body_normalization import normalize_single_dict as normalize_single_body_dict
from normalization.hand_normalization import normalize_single_dict as normalize_single_hand_dict
class CzechSLRDataset(torch_data.Dataset):
    """Torch ``Dataset`` serving landmark samples and their labels for the Czech SLR data.

    Samples and labels are read once through ``load_dataset`` when the object is built. Each item
    access converts the stored sample to a tensor, passes it through ``random_augmentation``,
    optionally normalizes body and hand landmarks, shifts the values by ``-0.5`` and finally
    applies the optional user-supplied transform.
    """

    data: [np.ndarray]
    labels: [np.ndarray]

    def __init__(self, dataset_filename: str, num_labels=5, transform=None, augmentations=False,
                 augmentations_prob=0.5, normalize=True):
        """
        Load the dataset and remember the per-item processing options.

        :param dataset_filename: Path of the dataset file, handed to ``load_dataset``
        :param num_labels: Kept on the instance as ``num_labels``; not read by this class itself
            (default: 5)
        :param transform: Optional callable applied last to every sample (default: None)
        :param augmentations: Flag forwarded to ``random_augmentation`` (default: False)
        :param augmentations_prob: Augmentation probability forwarded to ``random_augmentation``
            (default: 0.5)
        :param normalize: Whether body and hand normalization is run on every sample
            (default: True)
        """
        loaded = load_dataset(dataset_filename)

        # The first two entries of the loaded structure are the samples and their labels.
        self.data = loaded[0]
        self.labels = loaded[1]
        # Plain-list copy of the labels, exposed under the conventional ``targets`` name.
        self.targets = list(self.labels)

        self.num_labels = num_labels
        self.normalize = normalize
        self.augmentations = augmentations
        self.augmentations_prob = augmentations_prob
        self.transform = transform

    def __getitem__(self, idx):
        """
        Build the processed sample stored at the given index.

        :param idx: Index of the item
        :return: Tuple of the processed landmark tensor and a one-element label tensor
        """
        # Copy first so that nothing downstream can modify the array held in ``self.data``.
        sample = torch.from_numpy(np.copy(self.data[idx]))
        # The label is used exactly as stored (no index shift is applied).
        label = torch.Tensor([self.labels[idx]])

        # Augmentation and normalization operate on the dictionary representation.
        landmarks = tensor_to_dictionary(sample)
        landmarks = random_augmentation(self.augmentations, self.augmentations_prob, landmarks)

        if self.normalize:
            # Body normalization runs first; hand normalization receives its result.
            landmarks = normalize_single_hand_dict(normalize_single_body_dict(landmarks))

        # Back to a tensor, then move the landmark position interval down by 0.5
        # (presumably re-centring normalized coordinates around zero -- TODO confirm).
        shifted = dictionary_to_tensor(landmarks) - 0.5

        if not self.transform:
            return shifted, label

        return self.transform(shifted), label

    def __len__(self):
        """Return the number of items, i.e. the number of stored labels."""
        return len(self.labels)