Last active
October 20, 2023 11:28
-
-
Save twmht/d719f7e4b3ccaa8b269b7dd7475e023d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# from tvm.contrib.torch import optimize_torch | |
import tvm.tir.tensor_intrin | |
import contextlib | |
import tempfile | |
import tvm | |
import onnx | |
from tvm import meta_schedule as ms | |
from tvm import relay | |
def get_network(weight, batch_size, layout="NHWC", dtype="float32", use_sparse=False):
    """Load an ONNX model and lower it to a Relay module ready for tuning.

    Parameters
    ----------
    weight : str
        Path to the ONNX model file.
    batch_size : int
        Batch dimension of the fixed input shape ``(batch, 3, 224, 224)``.
    layout : str
        NOTE(review): currently unused — the NHWC layout conversion below is
        hard-coded. Kept for interface compatibility.
    dtype : str
        NOTE(review): currently unused — ``ToMixedPrecision`` decides the
        compute dtype. Kept for interface compatibility.
    use_sparse : bool
        NOTE(review): currently unused. Kept for interface compatibility.

    Returns
    -------
    tuple
        ``(mod, params, input_shape)`` — the transformed Relay module, its
        weight dict, and the fixed input shape.
    """
    # The model is assumed to take a single input named "input" with an
    # NCHW (batch, 3, 224, 224) shape — TODO confirm against the ONNX graph.
    input_shape = (batch_size, 3, 224, 224)
    onnx_model = onnx.load(weight)
    shape_dict = {"input": input_shape}
    mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

    # Convert conv/resize/upsampling ops to NHWC so NHWC schedules can apply.
    desired_layouts = {
        'nn.conv2d': ['NHWC', 'default'],
        'image.resize2d': ['NHWC'],
        'nn.upsampling': ['NHWC'],
    }
    seq = tvm.transform.Sequential([
        relay.transform.RemoveUnusedFunctions(),
        relay.transform.ConvertLayout(desired_layouts),
    ])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)

    mod = tvm.IRModule.from_expr(mod["main"])
    mod = tvm.relay.transform.FastMath()(mod)
    mod = tvm.relay.transform.EliminateCommonSubexpr()(mod)
    # Bind the weight dict into the module body so FoldConstant can see it.
    BindPass = tvm.relay.transform.function_pass(
        lambda fn, new_mod, ctx: tvm.relay.build_module.bind_params_by_name(fn, params),
        opt_level=1,
    )
    mod = BindPass(mod)
    mod = tvm.relay.transform.FoldConstant()(mod)
    mod = tvm.relay.transform.CombineParallelBatchMatmul()(mod)
    mod = tvm.relay.transform.FoldConstant()(mod)
    mod = tvm.relay.transform.InferType()(mod)
    # Cast eligible ops to mixed precision (float16 compute), then clean up
    # the casts that the conversion introduces.
    mod = tvm.relay.transform.ToMixedPrecision()(mod)
    mod = tvm.relay.transform.EliminateCommonSubexpr()(mod)
    mod = tvm.relay.transform.FoldConstant()(mod)
    mod = tvm.relay.transform.CombineParallelBatchMatmul()(mod)
    mod = tvm.relay.transform.FoldConstant()(mod)
    return mod, params, input_shape
# --- Tuning configuration --------------------------------------------------
weight = '/home/acer/tvm_experiment/resnet50.onnx'
work_dir = '/home/acer/test_meta_tensorcore'
batch_size = 1
layout = 'NHWC'
dtype = "float16"
use_sparse = False

mod, params, input_shape = get_network(
    weight,
    batch_size,
    layout,
    dtype=dtype,
    use_sparse=use_sparse,
)

# Reuse an explicit work_dir so a tuning run can be resumed/inspected;
# otherwise fall back to a throwaway temporary directory.
if work_dir:
    context_manager = contextlib.nullcontext(work_dir)
else:
    context_manager = tempfile.TemporaryDirectory()

# presumably the RTX 30-series target tag — verify it matches the local GPU
target = tvm.target.Target("nvidia/rtx-3000")

# Restrict the schedule search space to the tensor-core rule/postproc/mutator
# presets so MetaSchedule only explores tensor-core candidates.
space = ms.space_generator.PostOrderApply(
    sch_rules="cuda-tensorcore",
    postprocs="cuda-tensorcore",
    mutator_probs="cuda-tensorcore",
)
with context_manager as work_dir:  # pylint: disable=redefined-argument-from-local
    # Tune the whole Relay module with MetaSchedule; tuning records are
    # persisted as a JSON database under work_dir.
    database = ms.relay_integration.tune_relay(
        mod=mod,
        params=params,
        target=target,
        work_dir=work_dir,
        max_trials_global=25000,
        max_trials_per_task=256,
        num_trials_per_iter=64,
        builder='local',
        runner='local',
        database='json',
        cost_model='xgb',
        measure_callbacks='default',
        task_scheduler='gradient',
        space=space,
        strategy="evolutionary",
        seed=None,
    )
    # Compile with the tuned schedules applied (use_meta_schedule makes the
    # build pull schedules from the active database), then export the
    # deployable library.
    with database, tvm.transform.PassContext(
        opt_level=3,
        config={"relay.backend.use_meta_schedule": True},
    ):
        lib = relay.build(mod, target=target, params=params)
    lib.export_library('/home/acer/meta_resnet50.tar')
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.