Created April 10, 2020 05:28.
Save icemelon/cd41746fefac55d033f06059df69c747 to your computer and use it in GitHub Desktop.
test_nth
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np

import tvm
from tvm import relay
from tvm.relay.prelude import Prelude
from tvm.relay.ty import TensorType, TupleType
from tvm.runtime.container import ADT
def _get_relay_input_vars(input_shapes, prelude):
    """Create Relay input variables for the given (name, shape) pairs.

    Parameters
    ----------
    input_shapes : list of (str, shape)
        Each shape is either a sequence of ints (plain tensor), a tuple
        of shapes (Relay tuple type), or a Python list of identical
        shapes (Relay list type built via the prelude).
    prelude : tvm.relay.prelude.Prelude
        Supplies the Relay list type constructor for list-shaped inputs.

    Returns
    -------
    list of relay.Var
        One typed Relay variable per entry of ``input_shapes``.
    """
    def _is_int_seq(seq):
        # Non-empty sequence consisting solely of Python ints.
        return len(seq) > 0 and all(isinstance(i, int) for i in seq)

    def get_relay_ty(ishape):
        # Scalar (empty shape) or plain int shape -> tensor type.
        if _is_int_seq(ishape) or len(ishape) == 0:
            return TensorType(ishape)
        if isinstance(ishape, tuple):
            return TupleType([get_relay_ty(elem) for elem in ishape])
        if isinstance(ishape, list):
            assert len(ishape) > 0
            elem_tys = [get_relay_ty(s) for s in ishape]
            msg = "List elements should have identical types"
            # Relay lists are homogeneous, so every element type must match.
            assert all(ty == elem_tys[0] for ty in elem_tys), msg
            # prelude.l is the Relay list type constructor — TODO confirm
            # against the TVM version in use.
            return prelude.l(elem_tys[0])
        raise NotImplementedError("unsupported input type")

    return [relay.expr.var(name, type_annotation=get_relay_ty(shape))
            for name, shape in input_shapes]
def convert_to_list_adt(py_lst, prelude):
    """Convert a Python list of arrays into a Relay ADT list.

    The list is assembled as a chain of cons cells terminated by nil,
    so the first Python element becomes the head of the ADT list.
    """
    result = ADT(prelude.nil.tag, [])
    # Build back-to-front: each step prepends one element onto the chain.
    for item in py_lst[::-1]:
        result = ADT(prelude.cons.tag, [relay.const(item), result])
    return result
def test_nth():
    """Exercise prelude.nth on a Relay ADT list via the VM executor.

    Builds a Relay list of two random (batch, hidden_size) tensors,
    compiles ``nth(states, 0)`` with the VM, and checks the result is
    the first tensor of the list.
    """
    batch, hidden_size = 2, 4
    input_name = "states"
    # One input: a Relay list holding two (batch, hidden_size) tensors.
    input_shapes = [(input_name, [(batch, hidden_size), (batch, hidden_size)])]
    state_list = [np.random.uniform(size=shape) for shape in input_shapes[0][1]]

    mod = tvm.IRModule()
    prelude = Prelude(mod)
    adt_obj = convert_to_list_adt(state_list, prelude)
    params = {input_name: adt_obj}

    input_var = _get_relay_input_vars(input_shapes, prelude)[0]
    mod["main"] = tvm.relay.Function(
        [input_var], prelude.nth(input_var, relay.const(0)))

    executor = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
    evaluator = executor.evaluate()
    result = evaluator(**params)
    # The original ran the VM without inspecting the output; a test that
    # asserts nothing cannot fail on a wrong result. Verify nth(..., 0)
    # really returns the first tensor of the list.
    # NOTE(review): assumes the VM returns an NDArray exposing .asnumpy()
    # — confirm against the TVM version in use.
    np.testing.assert_allclose(result.asnumpy(), state_list[0])


if __name__ == "__main__":
    test_nth()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.