TensorFlow Serving Usage Demo
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 expandtab number
""" | |
关于 TFServing 用法的详细demo, 包括: | |
1. 训练一个MNIST模型 | |
2. 将其导出为 SavedModel 格式 | |
3. 使用TFServing Docker Image封装上述 SavedModel格式的模型,提供服务 | |
4. 演示client的 gRPC/REST 用法 | |
5. 多版本模型迭代(热插拔) | |
6. Serving并发度调优 | |
cat <<'EOF' > /tmp/mnist/models.config
model_config_list: {
  config: {
    name: "mnist",
    base_path: "/models/mnist",
    model_platform: "tensorflow",
    model_version_policy: {
      all: {}
    }
  }
}
EOF

cat <<'EOF' > /tmp/mnist/batching_parameters.txt
num_batch_threads { value: 32 }       # threads that run batched inference
batch_timeout_micros { value: 500 }   # how long to wait for a batch to fill up
max_batch_size { value: 512 }         # upper bound on requests merged into one batch
pad_variable_length_inputs: true      # pad inputs so differently-shaped requests can be batched
EOF

docker run -p 8500:8500 -p 8501:8501 \
    --rm --name mnist_serving \
    --mount type=bind,source=/tmp/mnist,target=/models/mnist \
    -e MODEL_NAME=mnist -t tensorflow/serving \
    --model_config_file=/models/mnist/models.config \
    --enable_batching=true --batching_parameters_file=/models/mnist/batching_parameters.txt
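
To exercise the hot swap of step 5 (a sketch, assuming the container above is still running
and new version directories keep appearing under the mounted /tmp/mnist): export a second
version and the server loads it on the fly; the REST status endpoint shows which versions
are available.

python tfserving_demo.py --model_version 2    # writes /tmp/mnist/2, picked up automatically
curl http://localhost:8501/v1/models/mnist    # lists each version and its state, e.g. AVAILABLE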

References:
- https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py
- https://www.tensorflow.org/tfx/serving/serving_basic
- https://www.cnblogs.com/marsggbo/p/10057220.html
- https://blog.csdn.net/JerryZhang__/article/details/86516428
- https://www.kaggle.com/scolianni/mnistasjpg/downloads/mnistasjpg.zip/1
- https://www.jianshu.com/p/e2184bf63505

Authors: qianweishuo<qzy922@gmail.com>
Date: 2019/9/1 11:35 AM
"""
from __future__ import print_function, unicode_literals  # small compatibility shim for py2

import os

import grpc  # pip install grpcio (for py3)
# Download this file from the official repo and put it somewhere on $PYTHONPATH, e.g.
# wget -P ./ https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_input_data.py && export PYTHONPATH=.
# noinspection PyUnresolvedReferences
import mnist_input_data
import tensorflow as tf

# The following lines play the same role as adding argparse argument definitions;
# they can be overridden from the command line, e.g. '--export_dir "/tmp/my_mnist"'
tf.app.flags.DEFINE_integer('training_iteration', 1000, 'number of training iterations.')
tf.app.flags.DEFINE_integer('model_version', 1, 'version number of the model.')
tf.app.flags.DEFINE_string('work_dir', '/tmp', 'Working directory.')
tf.app.flags.DEFINE_string('export_dir', '/tmp/mnist', 'Directory to which the model is exported.')

FLAGS = tf.app.flags.FLAGS


def training_model():
    """
    Train the model; return the session and the tensors needed for export.
    """
    print('Training model...')
    # Prepare the data
    mnist = mnist_input_data.read_data_sets(FLAGS.work_dir, one_hot=True)

    # Build the graph, loss and trainer
    sess = tf.InteractiveSession()
    serialized_tf_example = tf.placeholder(tf.string, name='tf_example')  # x serialized into a string
    tf_example = tf.parse_example(serialized=serialized_tf_example,  # deserialize
                                  features={'x': tf.FixedLenFeature(shape=[784], dtype=tf.float32), })
    x = tf.identity(tf_example['x'], name='x')  # use tf.identity() to assign a name
    y_true = tf.placeholder('float', shape=[None, 10])
    w = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y_pred = tf.nn.softmax(tf.matmul(x, w) + b, name='y')
    cross_entropy = -tf.reduce_sum(y_true * tf.log(y_pred))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    # Run the training loop
    sess.run(tf.global_variables_initializer())
    for _ in range(FLAGS.training_iteration):
        batch = mnist.train.next_batch(50)
        train_step.run(feed_dict={x: batch[0], y_true: batch[1]})

    # Report accuracy on the test set
    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_true, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
    print('training accuracy %g' % sess.run(
        accuracy, feed_dict={
            x: mnist.test.images,
            y_true: mnist.test.labels
        }))
    print('Done training!')
    return sess, serialized_tf_example, x, y_pred


def export_model(sess, serialized_tf_example, x, y_pred):
    # Note: the path components are converted to bytes before joining; for encoding
    # compatibility it is best to stick to pure-ASCII paths.
    export_path = os.path.join(tf.compat.as_bytes(FLAGS.export_dir), tf.compat.as_bytes(str(FLAGS.model_version)))
    print('Exporting trained model to', export_path)
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    # Build the signature_def_map; two interfaces at different levels are exposed
    # 1. Take the tensor-format input x directly; return the y tensor
    prediction_signature = (tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'images': tf.saved_model.utils.build_tensor_info(x)},
        outputs={'scores': tf.saved_model.utils.build_tensor_info(y_pred)},
        method_name="tensorflow/serving/predict"))

    # 2. Take the serialized serialized_tf_example; return the predicted classes and their scores.
    #    This one is more involved and has no corresponding client demo below.
    classification_inputs = tf.saved_model.utils.build_tensor_info(serialized_tf_example)  # model input
    values, indices = tf.nn.top_k(y_pred, 10)  # softmax values and indices of the top 10 classes
    prediction_classes = (tf.contrib.lookup
                          .index_to_string_table_from_tensor(tf.constant([str(i) for i in range(10)]))
                          .lookup(tf.to_int64(indices)))  # map the indices to strings via a lookup table
    classification_outputs_classes = tf.saved_model.utils.build_tensor_info(prediction_classes)  # candidate classes
    classification_outputs_scores = tf.saved_model.utils.build_tensor_info(values)  # and their scores
    classification_signature = (tf.saved_model.signature_def_utils.build_signature_def(
        # "inputs" can also be written as tf.saved_model.signature_constants.CLASSIFY_INPUTS; the other constant keys are analogous
        inputs={"inputs": classification_inputs},
        outputs={"classes": classification_outputs_classes, "scores": classification_outputs_scores},
        method_name="tensorflow/serving/classify"))

    # Export the model with the two signatures above
    builder.add_meta_graph_and_variables(
        sess=sess, tags=["serve"],  # the "serve" tag on the meta_graph is what TF Serving looks for when loading
        signature_def_map={'predict_images': prediction_signature, "serving_default": classification_signature},
        main_op=tf.tables_initializer(),
        strip_default_attrs=True)
    builder.save()
    print('Done exporting!')
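
# To double-check the exported signatures from the shell (a sketch, assuming the
# `saved_model_cli` tool that ships with the TensorFlow pip package and the default
# export path /tmp/mnist/1):
#   saved_model_cli show --dir /tmp/mnist/1 --all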


def get_test_images():
    test_data_set = mnist_input_data.read_data_sets(FLAGS.work_dir).test
    images, _labels = test_data_set.next_batch(3)
    assert (3, 784) == images.shape
    return images


def client_grpc(grpc_server='localhost:8500'):
    from tensorflow_serving.apis import predict_pb2  # pip install tensorflow-serving-api
    from tensorflow_serving.apis import prediction_service_pb2_grpc
    stub = prediction_service_pb2_grpc.PredictionServiceStub(grpc.insecure_channel(grpc_server))

    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'mnist'
    request.model_spec.signature_name = 'predict_images'
    request.model_spec.version.value = 1  # if unset, the highest version number is used by default
    images = get_test_images()  # input as an ndarray
    request.inputs['images'].CopyFrom(tf.contrib.util.make_tensor_proto(images, shape=images.shape))

    result = stub.Predict(request, 5.0)  # 5 secs timeout
    scores = tf.contrib.util.make_ndarray(result.outputs['scores'])  # extract the part we care about
    assert (images.shape[0], 10) == scores.shape


def client_rest(rest_server='http://localhost:8501/v1/models/mnist/versions/1:predict'):
    import requests
    import json
    import numpy as np

    images = get_test_images()  # input as an ndarray
    response = requests.post(rest_server, data=json.dumps(  # an ndarray cannot be dumped directly; call tolist() first
        {"signature_name": "predict_images", "instances": images.tolist()}))
    assert (images.shape[0], 10) == np.array(response.json()['predictions']).shape
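

# Hypothetical helper, not part of the original gist: a minimal sketch of checking the
# hot-swap state (step 5) over gRPC, assuming the ModelService API from the
# tensorflow-serving-api package and the same server started above on localhost:8500.
def model_status_grpc(grpc_server='localhost:8500'):
    from tensorflow_serving.apis import get_model_status_pb2
    from tensorflow_serving.apis import model_service_pb2_grpc
    stub = model_service_pb2_grpc.ModelServiceStub(grpc.insecure_channel(grpc_server))
    request = get_model_status_pb2.GetModelStatusRequest()
    request.model_spec.name = 'mnist'  # leave the version unset to ask about every loaded version
    response = stub.GetModelStatus(request, 5.0)  # 5 secs timeout
    print(response)  # lists each version with its state, e.g. AVAILABLE once hot-loaded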


def main(_):
    # python tfserving_demo.py --model_version 2
    sess, serialized_tf_example, x, y = training_model()
    export_model(sess, serialized_tf_example, x, y)
    client_grpc()
    client_rest()


if __name__ == '__main__':
    # roughly equivalent to FLAGS = argparse.parse(argv[1:]) followed by main()
    tf.app.run()