Skip to content

Instantly share code, notes, and snippets.

@gglin001
Last active July 12, 2023 04:59
Show Gist options
  • Save gglin001/94043b211ed2946ff93e04cc8f4823b1 to your computer and use it in GitHub Desktop.
Save gglin001/94043b211ed2946ff93e04cc8f4823b1 to your computer and use it in GitHub Desktop.
Encode and decode an ONNX model to and from QR codes.
import argparse
import binascii
import tarfile
from io import BytesIO
import qrcode
# Number of hex characters stored in each QR code. qrcode.make() uses
# error-correction level M by default; a version-40 symbol at level M holds
# 2331 bytes in byte mode, so 2300 leaves a small safety margin.
MAX_LEN = 2300


def tar_compress(name, data, mode="w:xz"):
    """Pack *data* as the single member of a compressed tar archive and return its bytes.

    Only the basename of *name* is stored. If the full path were stored, an
    absolute or ``..``-containing path would make ``extractall`` on the
    decoding side write outside the extraction directory.
    """
    import os

    tarinfo = tarfile.TarInfo(os.path.basename(name))
    tarinfo.size = len(data)
    tar_buffer = BytesIO()
    # The context manager guarantees the compressor is flushed/closed even on error.
    with tarfile.open(fileobj=tar_buffer, mode=mode) as tar:
        tar.addfile(tarinfo, BytesIO(data))
    return tar_buffer.getvalue()


def split_chunks(payload, max_len=MAX_LEN):
    """Split *payload* into a list of ``(offset, chunk)`` pairs, each chunk at most *max_len* long."""
    return [(i, payload[i : i + max_len]) for i in range(0, len(payload), max_len)]


def write_qr_codes(payload, max_len=MAX_LEN):
    """Write each chunk of *payload* to ``qr_<offset>.png`` as a version-40 QR code.

    The offset in the file name lets the decoder restore the chunk order.
    """
    for offset, chunk in split_chunks(payload, max_len):
        out_fp = f'qr_{offset}.png'
        print(f"encode to {out_fp} , len {len(chunk)}")
        qrcode.make(chunk, version=40).save(out_fp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("xx")
    parser.add_argument("--input_file", type=str, required=True, help="input file")
    args = parser.parse_args()
    fp = args.input_file

    with open(fp, "rb") as fh:
        raw = fh.read()

    # Alternatives: 'w:gz', 'w:bz2' -- the decoder must use the matching 'r:*' mode.
    mode = 'w:xz'
    # Hex-encode so the payload is plain ASCII inside the QR codes.
    xx_ = binascii.hexlify(tar_compress(fp, raw, mode))
    print(f"{mode}: {len(xx_)}")
    write_qr_codes(xx_)
import argparse
import binascii
import tarfile
import numpy as np
from io import BytesIO
import onnx
import qrcode
def clean_onnx(fp, max_elements=10):
    """Load the ONNX model at *fp* and strip the stored values of its large initializers.

    Initializers with more than *max_elements* elements keep their name,
    dims and data_type, but every value-carrying field is cleared. This
    shrinks the serialized model to (mostly) its graph structure. Smaller
    initializers are kept as-is: they are cheap to store and may hold
    shape/axis parameters the graph needs.

    Args:
        fp: path (or file-like object) accepted by ``onnx.load_model``.
        max_elements: initializers with more elements than this are cleared.

    Returns:
        The modified ``ModelProto``.
    """
    # TensorProto fields that can hold numeric tensor values. uint64_data
    # stores UINT32/UINT64 tensors and must be cleared too. string_data is
    # deliberately left alone: random restoration cannot regenerate strings.
    value_fields = (
        'int32_data',
        'int64_data',
        'uint64_data',
        'float_data',
        'double_data',
        'raw_data',
    )
    onnx_model = onnx.load_model(fp)
    for init in onnx_model.graph.initializer:
        num_elements = np.prod(init.dims)  # empty dims (scalar) -> 1.0
        if num_elements > max_elements:
            for field in value_fields:
                init.ClearField(field)
    return onnx_model
def clean_onnx_no_init(fp):
    """Load the ONNX model at *fp* and remove every initializer from its graph.

    Nodes that consumed those initializers are left unchanged, so their
    inputs now refer to names with no backing tensor.

    Returns:
        The modified ``ModelProto``.
    """
    onnx_model = onnx.load_model(fp)
    # TODO: keep axis/step/.. param for Gather/Slice and other Nodes
    # Clear the repeated field in one call instead of popping element by element.
    onnx_model.graph.ClearField('initializer')
    return onnx_model
# Number of hex characters stored in each QR code. qrcode.make() uses
# error-correction level M by default; a version-40 symbol at level M holds
# 2331 bytes in byte mode, so 2300 leaves a small safety margin.
MAX_LEN = 2300


def tar_compress(name, data, mode="w:xz"):
    """Pack *data* as the single member of a compressed tar archive and return its bytes.

    Only the basename of *name* is stored. If the full path were stored, an
    absolute or ``..``-containing path would make ``extractall`` on the
    decoding side write outside the extraction directory.
    """
    import os

    tarinfo = tarfile.TarInfo(os.path.basename(name))
    tarinfo.size = len(data)
    tar_buffer = BytesIO()
    # The context manager guarantees the compressor is flushed/closed even on error.
    with tarfile.open(fileobj=tar_buffer, mode=mode) as tar:
        tar.addfile(tarinfo, BytesIO(data))
    return tar_buffer.getvalue()


def split_chunks(payload, max_len=MAX_LEN):
    """Split *payload* into a list of ``(offset, chunk)`` pairs, each chunk at most *max_len* long."""
    return [(i, payload[i : i + max_len]) for i in range(0, len(payload), max_len)]


def write_qr_codes(payload, max_len=MAX_LEN):
    """Write each chunk of *payload* to ``qr_<offset>.png`` as a version-40 QR code.

    The offset in the file name lets the decoder restore the chunk order.
    """
    for offset, chunk in split_chunks(payload, max_len):
        out_fp = f'qr_{offset}.png'
        print(f"encode to {out_fp} , len {len(chunk)}")
        qrcode.make(chunk, version=40).save(out_fp)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("xx")
    parser.add_argument("--input_file", type=str, required=True, help="input file")
    args = parser.parse_args()

    # Example input:
    #   wget https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-v1-7.onnx
    #   --input_file resnet50-v1-7.onnx
    fp = args.input_file
    onnx_model = clean_onnx(fp)

    # Serialize the cleaned model in memory.
    onnx_buffer = BytesIO()
    onnx.save_model(onnx_model, onnx_buffer)

    # Alternatives: 'w:gz', 'w:bz2' -- the decoder must use the matching 'r:*' mode.
    mode = 'w:xz'
    # Hex-encode so the payload is plain ASCII inside the QR codes.
    xx_ = binascii.hexlify(
        tar_compress(f"{fp}.clean.onnx", onnx_buffer.getvalue(), mode)
    )
    print(f"{mode}: {len(xx_)}")
    write_qr_codes(xx_)
import binascii
import glob
import os
import tarfile
from io import BytesIO
from PIL import Image
from pyzbar.pyzbar import decode
def qr_index(path):
    """Return the payload offset encoded in a ``qr_<offset>.png`` file name.

    The encoder names each image after the offset of its chunk within the
    hex payload, so sorting by this value restores the original chunk order.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    return int(stem.split('_', 1)[1])


def read_qr_payload(fp):
    """Decode the QR code in image *fp* and return its payload bytes.

    Raises:
        ValueError: if no QR code can be decoded from the image.
    """
    with Image.open(fp) as img:  # close the file handle once decoded
        results = decode(img)
    if not results:
        raise ValueError(f"no QR code found in {fp}")
    # The encoder writes exactly one code per image; take the first.
    return results[0].data


def extract_payload(hex_payload, mode='r:xz', path='.'):
    """Un-hex *hex_payload*, open it as a tar archive and extract it into *path*.

    *mode* must match the compression used by the encoder (``'r:xz'`` for
    its ``'w:xz'``). The archive comes from images of unknown origin, so the
    ``'data'`` extraction filter is used when this Python provides it. That
    filter blocks members that would land outside *path* and special files.
    """
    tar_buffer = BytesIO(binascii.a2b_hex(hex_payload))
    with tarfile.open(fileobj=tar_buffer, mode=mode) as tar:
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path, filter='data')
        else:
            tar.extractall(path)


if __name__ == "__main__":
    fps = sorted(glob.glob('qr_*.png'), key=qr_index)
    if not fps:
        raise SystemExit("no qr_*.png files found in the current directory")

    chunks = []
    for fp in fps:
        data = read_qr_payload(fp)
        print(f"decode from {fp} , len {len(data)}")
        chunks.append(data)
    # Join once instead of repeated += on an immutable bytes object.
    extract_payload(b''.join(chunks))
import onnx
import numpy as np
import onnx.helper
import onnx.numpy_helper
import numpy as np
# Fixed seed so repeated restorations of the same model produce identical weights.
np.random.seed(1984)


def fill_onnx(fp, max_elements=10):
    """Load the ONNX model at *fp* and fill its large initializers with random values.

    Every initializer with more than *max_elements* elements gets uniform
    random values in [-1, 1), cast to its declared dtype and stored in
    ``raw_data``. Integer dtypes therefore receive truncated values, mostly
    0. Any values already stored in the typed ``*_data`` fields are cleared
    first, so ``raw_data`` is the tensor's only value source.

    STRING initializers are skipped: ONNX does not allow ``raw_data`` for
    strings, and ``tobytes()`` of an object array would serialize pointers.

    Args:
        fp: path (or file-like object) accepted by ``onnx.load_model``.
        max_elements: initializers with more elements than this are filled.

    Returns:
        The modified ``ModelProto``.
    """
    value_fields = (
        'int32_data',
        'int64_data',
        'uint64_data',
        'float_data',
        'double_data',
    )
    onnx_model = onnx.load_model(fp)
    for init in onnx_model.graph.initializer:
        num_elements = np.prod(init.dims)  # empty dims (scalar) -> 1.0
        if num_elements <= max_elements:
            continue
        if init.data_type == onnx.TensorProto.STRING:
            continue
        print(f"{init.name}: {init.dims}")
        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
        array = np.random.uniform(-1, 1, size=tuple(init.dims)).astype(np_dtype)
        for field in value_fields:
            init.ClearField(field)
        # NOTE(review): tobytes() uses host byte order; ONNX raw_data is
        # little-endian, so this assumes a little-endian host.
        init.raw_data = array.tobytes()
    return onnx_model
if __name__ == "__main__":
    import argparse

    # The encoder stores the cleaned model as "<model>.clean.onnx". The
    # default keeps the old fixed name; pass --input_file for the real member.
    parser = argparse.ArgumentParser("xx")
    parser.add_argument(
        "--input_file",
        type=str,
        default="clean.onnx",
        help="cleaned ONNX model to restore (default: clean.onnx)",
    )
    args = parser.parse_args()

    fp = args.input_file
    onnx_model = fill_onnx(fp)
    fp_restored = f"{fp}.restored.onnx"
    onnx.save(onnx_model, fp_restored)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment