Created
March 19, 2022 23:43
-
-
Save atksh/373d1b9390903ee887e7d13815b9be4f to your computer and use it in GitHub Desktop.
ONNX model optimization with optional dynamic quantization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import shutil | |
import tempfile | |
import onnx | |
import onnxruntime as rt | |
from onnxruntime.quantization import QuantType, quantize_dynamic | |
def opt_by_rt(input_model: str, output_model: str):
    """Run ONNX Runtime's graph optimizer on a model and save the result.

    Setting ``optimized_model_filepath`` on the session options asks ONNX
    Runtime to serialize the optimized graph to that path while the
    inference session is being created, so no inference is ever run here.

    Args:
        input_model: Path of the ONNX model to optimize.
        output_model: Path the optimized model is written to.
    """
    sess_options = rt.SessionOptions()
    sess_options.graph_optimization_level = (
        rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    )
    sess_options.optimized_model_filepath = output_model
    # The session is created purely for its side effect of writing the
    # optimized model, so it is deliberately not bound to a name.
    rt.InferenceSession(input_model, sess_options)
def remove_initializer_from_input(input_model: str, output_model: str):
    """Drop graph inputs that share a name with an initializer, then save.

    The model is loaded from ``input_model``; every entry of
    ``model.graph.input`` whose name matches an entry of
    ``model.graph.initializer`` is removed, and the modified model is
    written to ``output_model``.

    Args:
        input_model: Path of the ONNX model to read.
        output_model: Path the cleaned model is written to.
    """
    model = onnx.load(input_model)
    graph_inputs = model.graph.input
    # Index the inputs by name once so each initializer is a dict lookup.
    inputs_by_name = {
        graph_input.name: graph_input for graph_input in graph_inputs
    }
    for initializer in model.graph.initializer:
        if initializer.name in inputs_by_name:
            graph_inputs.remove(inputs_by_name[initializer.name])
    onnx.save(model, output_model)
def quantize(input_model: str, output_model: str):
    """Dynamically quantize an ONNX model to unsigned 8-bit.

    The model is staged into a temporary directory, quantized there, and
    the result is copied to ``output_model``. Because the work happens on
    staged copies, ``input_model`` and ``output_model`` may be the same path.

    Args:
        input_model: Path of the ONNX model to quantize.
        output_model: Path the quantized model is copied to.
    """
    with tempfile.TemporaryDirectory() as workdir:
        staged_input = os.path.join(workdir, "input.onnx")
        staged_output = os.path.join(workdir, "output.onnx")
        shutil.copyfile(input_model, staged_input)

        # Per-channel quantization with uint8 activations and weights.
        settings = dict(
            model_input=staged_input,
            model_output=staged_output,
            per_channel=True,
            activation_type=QuantType.QUInt8,
            weight_type=QuantType.QUInt8,
            optimize_model=True,
        )
        quantize_dynamic(**settings)

        # Copy out before the temporary directory is cleaned up.
        shutil.copyfile(staged_output, output_model)
def optimize(input_model: str, output_model: str, use_quantize: bool = False):
    """Run the full optimization pipeline on an ONNX model.

    First writes a copy of ``input_model`` with initializer-backed inputs
    removed to ``output_model``. Every later pass then rewrites
    ``output_model`` in place: quantization (only when ``use_quantize`` is
    true) followed by ONNX Runtime graph optimization.

    Args:
        input_model: Path of the source ONNX model.
        output_model: Path of the final optimized model.
        use_quantize: Whether to quantize before graph optimization.
    """
    remove_initializer_from_input(input_model, output_model)
    # Remaining passes read from and write back to the output file.
    in_place_passes = [quantize, opt_by_rt] if use_quantize else [opt_by_rt]
    for run_pass in in_place_passes:
        run_pass(output_model, output_model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment