Last active
April 15, 2021 10:08
-
-
Save ij96/ff1569db094b6a906b0d79020cc11e9b to your computer and use it in GitHub Desktop.
Code template for inferring an ONNX model in Python with ONNXRuntime
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Code template for infering ONNX model in Python with ONNXRuntime""" | |
import numpy as np | |
import onnxruntime as ort | |
import time | |
onnx_model_path = 'path/to/onnx/model.onnx' | |
# run the model on the ORT backend | |
session = ort.InferenceSession(onnx_model_path, None) | |
# get the name of the first input of the model | |
for i, item in enumerate(session.get_inputs()): | |
print(f'Input {i}:') | |
print(f' Name = {item.name}') | |
print(f' Shape = {item.shape}') | |
print(f' Type = {item.type}') | |
for i, item in enumerate(session.get_outputs()): | |
print(f'Output {i}:') | |
print(f' Name = {item.name}') | |
print(f' Shape = {item.shape}') | |
print(f' Type = {item.type}') | |
# create dummy data for inference | |
input_shape = (100, 128, 128, 1) | |
input_data = np.random.randn(*input_shape).astype(np.float32) | |
# inference - don't time the first run | |
_ = session.run([], dict({'input': input_data})) | |
start = time.time() | |
prediction = session.run([], dict({'input': input_data})) | |
end = time.time() | |
print(f'Inference time:{end - start:.4f}s') |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment