Last active
July 12, 2024 13:07
-
-
Save wangjiezhe/050854e2a12a4f05eab66da43b579dd9 to your computer and use it in GitHub Desktop.
Accelerating Inference in TensorFlow with TensorRT User Guide https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
from tensorflow import keras

# --- Build a simple MNIST classifier ---------------------------------------
# Flatten 28x28 grayscale images, one 128-unit ReLU hidden layer with
# dropout, and a 10-way output producing raw logits.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)  # raw logits; loss applies softmax internally
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# --- Load and normalize the MNIST dataset ----------------------------------
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0  # scale pixels to [0, 1]

# Cast inputs to float32 so the SavedModel gets a fixed float32 input
# signature (the TF-TRT conversion step below consumes this SavedModel).
# NOTE(review): labels are also cast to float32 here; integer labels would
# work equally well with SparseCategoricalCrossentropy.
x_train = tf.cast(x_train, dtype=tf.float32)
y_train = tf.cast(y_train, dtype=tf.float32)
x_test = tf.cast(x_test, dtype=tf.float32)
y_test = tf.cast(y_test, dtype=tf.float32)

# Train the model
model.fit(x_train, y_train, epochs=5)

# Evaluate your model accuracy
model.evaluate(x_test, y_test, verbose=2)

# Save model in the saved_model format (the input format TF-TRT expects)
SAVED_MODEL_DIR = "./models/native_saved_model"
tf.saved_model.save(model, SAVED_MODEL_DIR)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# --- Convert the SavedModel with TF-TRT ------------------------------------
# Instantiate the TF-TRT converter. FP32 preserves full precision; FP16 or
# INT8 modes would trade accuracy for additional speed.
# NOTE(review): SAVED_MODEL_DIR and x_test are defined by the preceding
# training snippet — this snippet assumes it runs in the same session.
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR,
    precision_mode=trt.TrtPrecisionMode.FP32
)

# Convert the model into TRT-compatible segments.
trt_func = converter.convert()
converter.summary()

MAX_BATCH_SIZE = 128


def input_fn():
    """Yield one representative batch so TRT engines can be pre-built."""
    batch_size = MAX_BATCH_SIZE
    x = x_test[0:batch_size, :]
    yield [x]


# Build the TensorRT engines ahead of time using the representative input,
# so the first real inference call does not pay the engine-build cost.
converter.build(input_fn=input_fn)

OUTPUT_SAVED_MODEL_DIR = "./models/tftrt_saved_model"
converter.save(output_saved_model_dir=OUTPUT_SAVED_MODEL_DIR)

# --- Get batches of test data and run inference through them ---------------
infer_batch_size = MAX_BATCH_SIZE // 2
for i in range(10):
    print(f"Step: {i}")
    start_idx = i * infer_batch_size
    end_idx = (i + 1) * infer_batch_size
    x = x_test[start_idx:end_idx, :]
    trt_func(x)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment