Resolve WES-41 "Normalization"
This commit is contained in:
committed by
Victor Mylle
parent
ba44762eba
commit
9f7197e4e9
31
export.py
Normal file
31
export.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
import torchvision
|
||||
import onnx
|
||||
import numpy as np
|
||||
|
||||
from src.model import SPOTER
|
||||
from src.identifiers import LANDMARKS
|
||||
|
||||
model_name = 'Fingerspelling_AE'
|
||||
|
||||
# load PyTorch model from .pth file
|
||||
model = SPOTER(num_classes=5, hidden_dim=len(LANDMARKS) *2)
|
||||
state_dict = torch.load('models/' + model_name + '.pth')
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
# set model to evaluation mode
|
||||
model.eval()
|
||||
|
||||
# create dummy input tensor
|
||||
batch_size = 1
|
||||
num_of_frames = 1
|
||||
input_shape = (108, num_of_frames)
|
||||
dummy_input = torch.randn(batch_size, *input_shape)
|
||||
|
||||
# export model to ONNX format
|
||||
output_file = 'models/' + model_name + '.onnx'
|
||||
torch.onnx.export(model, dummy_input, output_file, input_names=['input'], output_names=['output'])
|
||||
|
||||
# load exported ONNX model for verification
|
||||
onnx_model = onnx.load(output_file)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
Reference in New Issue
Block a user