Some fixes
This commit is contained in:
66
export_model.py
Normal file
66
export_model.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from models.spoter_embedding_model import SPOTER_EMBEDDINGS
|
||||
|
||||
# set parameters of the model
|
||||
model_name = 'embedding_model'
|
||||
output=32
|
||||
|
||||
# load PyTorch model from .pth file
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
|
||||
CHECKPOINT_PATH = "out-checkpoints/augment_rotate_75_x8/checkpoint_embed_1105.pth"
|
||||
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
|
||||
|
||||
model = SPOTER_EMBEDDINGS(
|
||||
features=checkpoint["config_args"].vector_length,
|
||||
hidden_dim=checkpoint["config_args"].hidden_dim,
|
||||
norm_emb=checkpoint["config_args"].normalize_embeddings,
|
||||
).to(device)
|
||||
model.load_state_dict(checkpoint["state_dict"])
|
||||
# set model to evaluation mode
|
||||
model.eval()
|
||||
|
||||
model_export = "onnx"
|
||||
if model_export == "coreml":
|
||||
dummy_input = torch.randn(1, 10, 54, 2)
|
||||
traced_model = torch.jit.trace(model, dummy_input)
|
||||
out = traced_model(dummy_input)
|
||||
import coremltools as ct
|
||||
|
||||
# Convert to Core ML
|
||||
coreml_model = ct.convert(
|
||||
traced_model,
|
||||
inputs=[ct.TensorType(name="input", shape=dummy_input.shape)],
|
||||
)
|
||||
|
||||
# Save Core ML model
|
||||
coreml_model.save("models/" + model_name + ".mlmodel")
|
||||
else:
|
||||
# create dummy input tensor
|
||||
dummy_input = torch.randn(1, 10, 54, 2)
|
||||
|
||||
# export model to ONNX format
|
||||
output_file = 'models/' + model_name + '.onnx'
|
||||
torch.onnx.export(model, dummy_input, output_file, input_names=['input'], output_names=['output'])
|
||||
|
||||
torch.onnx.export(model, # model being run
|
||||
dummy_input, # model input (or a tuple for multiple inputs)
|
||||
'output-models/' + model_name + '.onnx', # where to save the model (can be a file or file-like object)
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=9, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names = ['X'], # the model's input names
|
||||
output_names = ['Y'] # the model's output names
|
||||
)
|
||||
|
||||
|
||||
# load exported ONNX model for verification
|
||||
onnx_model = onnx.load(output_file)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
Reference in New Issue
Block a user