Rotation augmentation class added

This commit is contained in:
RobbeDeWaele
2023-03-30 16:13:03 +02:00
parent 7793122eac
commit 0af9320571
3 changed files with 65 additions and 1411 deletions

View File

@@ -124,4 +124,16 @@ class NoiseAugmentation:
# Adds i.i.d. Gaussian noise, scaled by self.noise, to every keypoint coordinate.
# NOTE(review): this is a diff rendering — leading indentation was stripped by the
# diff viewer, so the lines below appear at column 0.
def __call__(self, sample):
# add zero-mean Gaussian noise to the keypoints; torch.randn draws a tensor
# matching sample's shape, so every coordinate gets independent noise
sample = sample + torch.randn(sample.shape) * self.noise
return sample
# NOTE(review): the duplicated `return sample` below is diff-rendering residue
# (old line removed / new line added); it is unreachable in the actual file
return sample
# augmentation to rotate all keypoints around 0,0
class RotateAugmentation:
    """Rotate every keypoint in a sample around the origin (0, 0) by a
    random angle drawn uniformly from [-13, +13] degrees.

    Assumes the last dimension of ``sample`` holds coordinates with x in
    channel 0 and y in channel 1 (TODO confirm against the dataset);
    any extra channels (e.g. z) pass through unchanged.
    """

    def __call__(self, sample):
        # generate a random angle between -13 and 13 degrees, in radians
        angle_max = 13.0
        angle = math.radians(random.uniform(a=-angle_max, b=angle_max))
        cos_a = math.cos(angle)
        sin_a = math.sin(angle)
        # BUG FIX: the original used `new_sample = sample`, which aliases the
        # same tensor — the write to channel 0 clobbered the x values that the
        # y-rotation below still reads, AND mutated the caller's input.
        # Clone first so both rotation rows read the untouched original.
        new_sample = sample.clone()
        new_sample[:, :, 0] = sample[:, :, 0] * cos_a - sample[:, :, 1] * sin_a
        new_sample[:, :, 1] = sample[:, :, 0] * sin_a + sample[:, :, 1] * cos_a
        return new_sample

View File

@@ -8,7 +8,7 @@ import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation
from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation, RotateAugmentation
from src.datasets.finger_spelling_dataset import FingerSpellingDataset
from src.identifiers import LANDMARKS
from src.model import SPOTER
@@ -29,12 +29,16 @@ def train():
g = torch.Generator()
g.manual_seed(379)
device = torch.device("cuda:0")
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) *2)
# use cuda if available
if torch.cuda.is_available():
device = torch.device("cuda:0")
else:
device = torch.device("cpu")
spoter_model.train(True)
spoter_model.to(device)
criterion = nn.CrossEntropyLoss()
criterion_bad = CustomLoss()
@@ -45,7 +49,7 @@ def train():
if not os.path.exists("checkpoints"):
os.makedirs("checkpoints")
transform = transforms.Compose([MirrorKeypoints(), NoiseAugmentation(noise=0.1)])
transform = transforms.Compose([MirrorKeypoints(), NoiseAugmentation(noise=0.1), RotateAugmentation()])
train_set = FingerSpellingDataset("data/fingerspelling/data/", bad_data_folder="", keypoints_identifier=LANDMARKS, subset="train", transform=transform)
train_loader = DataLoader(train_set, shuffle=True, generator=g)

File diff suppressed because one or more lines are too long