Last active
May 4, 2023 05:51
-
-
Save usstq/610071c2009f7fbcb7806f28c4a82370 to your computer and use it in GitHub Desktop.
A CLI tool for inspecting ONNX models.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import onnx | |
import onnx.numpy_helper | |
import sys | |
def get_value_info(m, name):
    """Return a printable "%name[TYPE, shape]" description of value `name`.

    `name` may carry the '%' prefix used by onnx.helper.printable_graph.
    The value is looked up in the graph's value_info (filled in by shape
    inference), then in the graph outputs, then in the graph inputs: graph
    inputs/outputs carry their own types there and are generally not
    repeated in value_info.

    If no entry matches, `name` is returned exactly as given (prefix
    included), so un-annotated values keep their '%' in the printed graph.
    """
    bare = name[1:] if name.startswith('%') else name
    for group in (m.graph.value_info, m.graph.output, m.graph.input):
        for info in group:
            if info.name == bare:
                return onnx.helper.printable_value_info(info)
    return name
def read_attr_value(m, fuzzy_name):
    """Return, as text, the constant carried by the node that outputs `fuzzy_name`.

    `fuzzy_name` is an output name, optionally prefixed with '%' as printed
    by onnx.helper.printable_graph. The first node listing it among its
    outputs is used. Its constant is taken from whichever of these
    attributes appears first:
      * 'value' (a tensor, converted to a numpy array), or
      * 'sparse_value' / 'value_*' (value_float, value_ints, value_string,
        ...), which are the other forms an ONNX Constant node may use.

    Returns:
        str(value); "None" if the producing node has none of those
        attributes; "?" if no node produces `fuzzy_name`.
    """
    if fuzzy_name.startswith('%'):
        fuzzy_name = fuzzy_name[1:]
    for node in m.graph.node:
        if fuzzy_name not in node.output:
            continue
        for attr in node.attribute:
            if attr.name == 'value':
                return f"{onnx.numpy_helper.to_array(attr.t)}"
            if attr.name == 'sparse_value' or attr.name.startswith('value_'):
                return f"{onnx.helper.get_attribute_value(attr)}"
        # Producer found but it carries no constant attribute.
        return "None"
    return "?"
def read_const(m, fuzzy_name):
    """Print what the model knows about `fuzzy_name`.

    `fuzzy_name` may be a node name, a node output name or an initializer
    name, optionally prefixed with '%' as printed by
    onnx.helper.printable_graph. Lookups, in order:

      1. value_info / graph input / graph output: print its type and shape,
         then keep going so the value itself can be printed as well.
      2. A node whose *name* matches: print the whole node and stop.
      3. A node that *outputs* it: print "value: node:output_index(name) value"
         and stop. The value comes from the node's 'value' tensor attribute or
         its 'sparse_value' / 'value_*' attribute; it is None if there is none.
      4. An initializer: print the tensor, then its dtype and shape.

    If nothing matches, a notice is written to stderr instead of printing
    nothing at all.
    """
    if fuzzy_name.startswith('%'):
        fuzzy_name = fuzzy_name[1:]
    found = False

    info = next(
        (vi
         for group in (m.graph.value_info, m.graph.input, m.graph.output)
         for vi in group
         if vi.name == fuzzy_name),
        None,
    )
    if info is not None:
        print(f"value_info: {onnx.helper.printable_value_info(info)}")
        found = True

    for node in m.graph.node:
        if node.name == fuzzy_name:
            print(f"{node}")
            return
        if fuzzy_name in node.output:
            o_idx = list(node.output).index(fuzzy_name)
            value = None
            for attr in node.attribute:
                if attr.name == 'value':
                    value = onnx.numpy_helper.to_array(attr.t)
                    break
                if attr.name == 'sparse_value' or attr.name.startswith('value_'):
                    value = onnx.helper.get_attribute_value(attr)
                    break
            print(f"value: {node.name}:{o_idx}({fuzzy_name}) {value}")
            return

    for t in m.graph.initializer:
        if t.name == fuzzy_name:
            tensor = onnx.numpy_helper.to_array(t)
            print(f"initializer: {fuzzy_name}={tensor} \n{tensor.dtype} {tensor.shape}")
            found = True

    if not found:
        print(f"{fuzzy_name}: not found in model", file=sys.stderr)
# Command-line entry point.
#   inspect.py MODEL           -> print the shape-inferred graph, annotating node
#                                 outputs with type/shape and Constant nodes with
#                                 their value.
#   inspect.py MODEL NAME ...  -> print what is known about each NAME.
if len(sys.argv) == 1:
    print("inspect.py onnx_model [node_name | output_name | initializer_name] ...")
    print(" You can check printable graph or any constant/initializer")
    sys.exit(0)

m = onnx.load(sys.argv[1])

if len(sys.argv) == 2:
    # Shape inference populates m.graph.value_info, which get_value_info
    # uses to annotate intermediate values.
    m = onnx.shape_inference.infer_shapes(m)
    for line in onnx.helper.printable_graph(m.graph).splitlines():
        # Node lines look like "  %out1, %out2 = Op[...](...)". Lines without
        # " = " (graph header, inputs, "return ...") are printed unchanged.
        # Attribute text can also contain " = ", so only the first piece is
        # treated as the output list; the rest is re-joined untouched.
        parts = line.split(" = ")
        suffix = ""
        if len(parts) >= 2:
            lhs = parts[0]
            indent = lhs[:len(lhs) - len(lhs.lstrip())]
            names = [n.strip() for n in lhs.strip().split(",")]
            parts[0] = indent + ", ".join(get_value_info(m, n).strip() for n in names)
            if parts[1].startswith("Constant"):
                suffix = f"value={read_attr_value(m, names[0])}"
        text = " = ".join(parts)
        # Only add the separator when there is something to append.
        print(f"{text} {suffix}" if suffix else text)
    sys.exit(0)

# Inspect each requested node / output / initializer name.
for name in sys.argv[2:]:
    read_const(m, name)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment