Rotation augmentation class added
This commit is contained in:
@@ -124,4 +124,16 @@ class NoiseAugmentation:
|
||||
def __call__(self, sample):
|
||||
# add noise to the keypoints
|
||||
sample = sample + torch.randn(sample.shape) * self.noise
|
||||
return sample
|
||||
return sample
|
||||
|
||||
# augmentation to rotate all keypoints around 0,0
|
||||
class RotateAugmentation:
|
||||
def __call__(self, sample):
|
||||
# generate a random angle between -13 and 13 degrees
|
||||
angle_max = 13.0
|
||||
angle = math.radians(random.uniform(a=-angle_max, b=angle_max))
|
||||
# rotate the keypoints around 0.0
|
||||
new_sample = sample
|
||||
new_sample[:, :, 0] = sample[:, :, 0]*math.cos(angle) - sample[:, :, 1]*math.sin(angle)
|
||||
new_sample[:, :, 1] = sample[:, :, 0]*math.sin(angle) + sample[:, :, 1]*math.cos(angle)
|
||||
return new_sample
|
||||
14
src/train.py
14
src/train.py
@@ -8,7 +8,7 @@ import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
|
||||
from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation
|
||||
from src.augmentations import MirrorKeypoints, Z_augmentation, NoiseAugmentation, RotateAugmentation
|
||||
from src.datasets.finger_spelling_dataset import FingerSpellingDataset
|
||||
from src.identifiers import LANDMARKS
|
||||
from src.model import SPOTER
|
||||
@@ -29,12 +29,16 @@ def train():
|
||||
g = torch.Generator()
|
||||
g.manual_seed(379)
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
spoter_model = SPOTER(num_classes=26, hidden_dim=len(LANDMARKS) *2)
|
||||
|
||||
# use cuda if available
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda:0")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
|
||||
spoter_model.train(True)
|
||||
spoter_model.to(device)
|
||||
|
||||
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
criterion_bad = CustomLoss()
|
||||
@@ -45,7 +49,7 @@ def train():
|
||||
if not os.path.exists("checkpoints"):
|
||||
os.makedirs("checkpoints")
|
||||
|
||||
transform = transforms.Compose([MirrorKeypoints(), NoiseAugmentation(noise=0.1)])
|
||||
transform = transforms.Compose([MirrorKeypoints(), NoiseAugmentation(noise=0.1), RotateAugmentation()])
|
||||
|
||||
train_set = FingerSpellingDataset("data/fingerspelling/data/", bad_data_folder="", keypoints_identifier=LANDMARKS, subset="train", transform=transform)
|
||||
train_loader = DataLoader(train_set, shuffle=True, generator=g)
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user