ONNX optimization with optional dynamic quantization
"""ONNX model optimization pipeline: strip initializer inputs, optionally
quantize dynamically, then apply onnxruntime graph optimizations."""
import os
import shutil
import tempfile

import onnx
import onnxruntime as rt
from onnxruntime.quantization import QuantType, quantize_dynamic


def opt_by_rt(input_model: str, output_model: str):
    """Run onnxruntime graph optimizations and save the optimized model."""
    sess_options = rt.SessionOptions()
    sess_options.graph_optimization_level = (
        rt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    )
    sess_options.optimized_model_filepath = output_model
    # Creating the session applies the optimizations and, as a side effect,
    # writes the optimized graph to optimized_model_filepath.
    rt.InferenceSession(input_model, sess_options)


def remove_initializer_from_input(input_model: str, output_model: str):
    """Drop graph inputs that are really initializers (baked-in weights)."""
    model = onnx.load(input_model)
    inputs = model.graph.input
    # Avoid shadowing the built-in `input` while indexing inputs by name.
    name_to_input = {graph_input.name: graph_input for graph_input in inputs}
    for initializer in model.graph.initializer:
        if initializer.name in name_to_input:
            inputs.remove(name_to_input[initializer.name])
    onnx.save(model, output_model)


def quantize(input_model: str, output_model: str):
    """Apply dynamic uint8 quantization, working in a temporary directory
    so intermediate files don't clutter the model's directory."""
    with tempfile.TemporaryDirectory() as tmpdir:
        input_path = os.path.join(tmpdir, "input.onnx")
        output_path = os.path.join(tmpdir, "output.onnx")
        shutil.copyfile(input_model, input_path)
        quantize_dynamic(
            model_input=input_path,
            model_output=output_path,
            per_channel=True,
            activation_type=QuantType.QUInt8,
            weight_type=QuantType.QUInt8,
            # Accepted by the onnxruntime versions this gist targets;
            # newer releases may no longer take this argument.
            optimize_model=True,
        )
        shutil.copyfile(output_path, output_model)


def optimize(input_model: str, output_model: str, use_quantize: bool = False):
    """Full pipeline: clean up inputs, optionally quantize, then optimize."""
    remove_initializer_from_input(input_model, output_model)
    if use_quantize:
        quantize(output_model, output_model)
    opt_by_rt(output_model, output_model)
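

# Usage sketch, assuming a model at the hypothetical path "model.onnx";
# substitute your own input and output paths.
if __name__ == "__main__":
    optimize("model.onnx", "model.opt.onnx", use_quantize=True)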