@dayyass
Created November 25, 2020 16:49
ONNX export doesn't support PyTorch Adaptive Pooling (and therefore Global Pooling, the special case with output_size=1); a minimal sketch of the problem follows the imports below. Here is an implementation of Global Pooling that is compatible with ONNX.
import numpy as np
import torch
import torch.nn as nn
import onnx
import onnxruntime
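
# for context, a minimal sketch of the direct export described above:
# depending on the PyTorch / opset version, exporting adaptive pooling either
# raises or bakes static spatial sizes into the graph (illustrative only)
try:
    torch.onnx.export(
        nn.AdaptiveMaxPool2d(output_size=1),  # global pooling as adaptive pooling with output_size=1
        torch.randn(1, 1, 224, 224),          # dummy input
        'adaptive_pooling.onnx',
        opset_version=10,
    )
except RuntimeError as e:
    print(f'direct adaptive pooling export failed: {e}')
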
##### INIT 1d, 2d, 3d GLOBAL POOLING MODULES #####

class GlobalAvgPool1d(nn.Module):
    """
    Reduce mean over last dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.mean(dim=-1, keepdim=True)


class GlobalMaxPool1d(nn.Module):
    """
    Reduce max over last dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.max(dim=-1, keepdim=True)[0]


class GlobalAvgPool2d(nn.Module):
    """
    Reduce mean over last two dimensions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.mean(dim=-1, keepdim=True)
        return x.mean(dim=-2, keepdim=True)


class GlobalMaxPool2d(nn.Module):
    """
    Reduce max over last two dimensions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.max(dim=-1, keepdim=True)[0]
        return x.max(dim=-2, keepdim=True)[0]


class GlobalAvgPool3d(nn.Module):
    """
    Reduce mean over last three dimensions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.mean(dim=-1, keepdim=True)
        x = x.mean(dim=-2, keepdim=True)
        return x.mean(dim=-3, keepdim=True)


class GlobalMaxPool3d(nn.Module):
    """
    Reduce max over last three dimensions.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x.max(dim=-1, keepdim=True)[0]
        x = x.max(dim=-2, keepdim=True)[0]
        return x.max(dim=-3, keepdim=True)[0]
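
##### SANITY CHECK AGAINST ADAPTIVE POOLING #####

# quick sketch: each module above should be numerically identical to the
# corresponding adaptive pooling layer with output_size=1
x = torch.randn(2, 3, 16, 16)
assert torch.allclose(GlobalAvgPool2d()(x), nn.AdaptiveAvgPool2d(1)(x))
assert torch.allclose(GlobalMaxPool2d()(x), nn.AdaptiveMaxPool2d(1)(x))
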
##### EXAMPLE OF ONNX EXPORT #####

global_pooling = GlobalMaxPool2d()  # init global pooling layer

# input to the global pooling layer
tensor = torch.randn(1, 1, 224, 224)  # first two dimensions are 1; the dynamic axes below make every dim variable
torch_out = global_pooling(tensor)  # torch inference

# export the global pooling layer
torch.onnx.export(
    model=global_pooling,      # model being run
    args=tensor,               # model input (or a tuple for multiple inputs)
    f="global_pooling.onnx",   # where to save the model (can be a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=10,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the model's input names
    output_names=['output'],   # the model's output names
    dynamic_axes={             # variable length axes
        'input': {0: 'batch_size', 1: 'n_channel', 2: 'height', 3: 'width'},
        'output': {0: 'batch_size', 1: 'n_channel'},
    },
)
onnx_model = onnx.load('global_pooling.onnx')  # load onnx global pooling layer
onnx.checker.check_model(onnx_model)  # verify the model's structure and confirm that it has a valid schema
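# optionally inspect the exported graph: the pooling is expected to appear as
# ReduceMax nodes (ReduceMean for the avg variants) rather than a pooling op,
# though the exact ops depend on the exporter version
print(onnx.helper.printable_graph(onnx_model.graph))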
ort_session = onnxruntime.InferenceSession('global_pooling.onnx')  # create an onnxruntime inference session
# compute ONNX Runtime output prediction
ort_inputs = {'input': tensor.numpy()}
ort_outs = ort_session.run(None, ort_inputs)
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
print('Exported model has been tested with ONNXRuntime, and the result looks good!')
# check inference on a tensor with a different shape
tensor = torch.randn(2, 3, 128, 256)
torch_out = global_pooling(tensor)  # torch inference
ort_inputs = {'input': tensor.numpy()}  # onnx input
ort_outs = ort_session.run(None, ort_inputs)  # onnx inference
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)  # compare torch with onnx
print('ONNX model works with tensors of arbitrary shape!')
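
# the same recipe works for the 1d and 3d modules; only the dynamic_axes
# mapping changes (a sketch with assumed axis names)
torch.onnx.export(
    GlobalAvgPool1d(),
    torch.randn(1, 1, 128),
    'global_avg_pool1d.onnx',
    opset_version=10,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size', 1: 'n_channel', 2: 'length'},
        'output': {0: 'batch_size', 1: 'n_channel'},
    },
)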