Created September 9, 2022 23:05
-
-
Save antiagainst/0d2cc3463299b497b797662ddbffa3db to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copied from https://colab.sandbox.google.com/github/iree-org/iree/blob/main/samples/colab/resnet.ipynb
# Run the following commands to install the needed packages:
#   pip install --upgrade iree-compiler iree-runtime iree-tools-tf -f https://github.com/iree-org/iree/releases
#   pip install --upgrade tf-nightly
import os

from absl import app
from iree import compiler as ireec
from iree import runtime as ireert
from iree.tf.support import module_utils
import tensorflow as tf
print("TensorFlow version: ", tf.__version__)

# Static input shape of the compiled entry point: one 224x224 image with 3
# channels. Presumably (batch, height, width, channels), i.e. Keras' default
# channels_last layout -- confirm if the data format is ever changed.
INPUT_SHAPE = [1, 224, 224, 3]

# ResNet50 with the "imagenet" weights and its classification head
# (include_top=True). Keras expects the per-example shape, so the leading
# batch dimension of INPUT_SHAPE is sliced off.
tf_model = tf.keras.applications.resnet50.ResNet50(
    include_top=True,
    weights="imagenet",
    input_shape=tuple(INPUT_SHAPE)[1:],
)
# Wrap the model in a tf.Module to compile it with IREE.
class ResNetModule(tf.Module):
    """tf.Module exposing the global Keras ResNet50 through a `predict` function.

    `predict` is the tf.function that gets exported when this class is
    compiled with exported_names=["predict"]. Its input signature is fixed to
    a float32 tensor of shape INPUT_SHAPE.
    """

    def __init__(self):
        super().__init__()
        # Keep the model as an attribute so tf.Module tracks it (and, presumably,
        # its variables) when the module is saved for import.
        self.m = tf_model

    @tf.function(input_signature=[tf.TensorSpec(INPUT_SHAPE, tf.float32)])
    def predict(self, x):
        """Run one forward pass of the wrapped model with training=False."""
        # Call the forward pass directly instead of monkeypatching
        # `tf_model.predict`. Monkeypatching would overwrite Keras'
        # Model.predict on the shared global model instance.
        return self.m.call(x, training=False)
def main(argv):
    """Import ResNet50 through IREE's TF tooling, then recompile it for Vulkan.

    The importer writes its artifacts, including ``iree_input.mlir``, into
    ``./iree-resnet50-artifacts``. The recompiled module is written to the
    same directory as ``amd-resnet50.vmfb``.

    Args:
      argv: Positional command-line arguments forwarded by ``app.run``; unused.
    """
    del argv  # Unused; all configuration is hard-coded below.

    # Single source of truth for the output directory, so the import step and
    # the read/write steps below cannot drift apart.
    artifacts_dir = "./iree-resnet50-artifacts"

    backend = module_utils.BackendInfo("iree_vulkan")
    # Import TF model into MLIR format. This generates quite a few artifacts
    # inside the directory; they are the model representation at different
    # levels. We only need the mhlo representation there. It also compiles
    # the model for vulkan, but that's using the default parameters; so we
    # will discard that and recompile later (as there is no way to control
    # the compilation options here).
    backend.compile_from_class(ResNetModule,
                               exported_names=["predict"],
                               artifacts_dir=artifacts_dir)

    # Read in the imported mhlo representation of the model. Decode it
    # explicitly as UTF-8 rather than relying on the platform-default encoding.
    with open(os.path.join(artifacts_dir, "iree_input.mlir"),
              encoding="utf-8") as f:
        mhlo_source = f.read()

    # Compile using the iree-compile wrapper. Here we have access to all
    # developer command-line option controls.
    compilation_args = [
        "--iree-vulkan-target-triple=rdna2-unknown-linux",
        "--mlir-print-debuginfo=false",
        # Add more command-line options you'd like to pass to iree-compile
        # here, for example:
        # "--mlir-print-ir-after=iree-hal-materialize-interfaces",
        # "--mlir-elide-elementsattrs-if-larger=8",
    ]
    blob = ireec.compile_str(mhlo_source,
                             target_backends=["vulkan"],
                             extra_args=compilation_args,
                             input_type="mhlo")

    # Write out the compiled IREE module bytes returned by compile_str.
    with open(os.path.join(artifacts_dir, "amd-resnet50.vmfb"), "wb") as f:
        f.write(blob)
# Script entry point: hand `main` to absl, which invokes it with the
# remaining command-line arguments.
if __name__ == "__main__":
    app.run(main=main)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment