https://www.jianshu.com/p/d673c9507988
通过简单运行了官网例子,对tensorflow serving有了大致的了解,但是怎么把自己的模型发布成服务呢?现在通过一个小例子来学习下。
0. 介绍
这里介绍两种保存模型的方式,发布服务需要的不再是之前保存的ckpt格式数据,而是export出来的模型或者pb模型。通过这两种方式把模型准备好,之后只需要挂在到指定路径下,就可以起服务了。
1. 1 exporter 模型
把官方的half_plus_two简单修改成了half_plus_ten。
与我们保存ckpt不同,需要调用的接口是:
from tensorflow.contrib.session_bundle import exporter
需要把输入输出给重新定义下,然后再用接口导出。
import tensorflow as tf
from tensorflow.contrib.session_bundle import exporter
def Export():
export_path = "model/half_plus_ten"
with tf.Session() as sess:
# Make model parameters a&b variables instead of constants to
# exercise the variable reloading mechanisms.
a = tf.Variable(0.5)
b = tf.Variable(10.0)
# Calculate, y = a*x + b
# here we use a placeholder 'x' which is fed at inference time.
x = tf.placeholder(tf.float32)
y = tf.add(tf.multiply(a, x), b)
# Run an export.
tf.global_variables_initializer().run()
export = exporter.Exporter(tf.train.Saver())
export.init(named_graph_signatures={
"inputs": exporter.generic_signature({"x": x}),
"outputs": exporter.generic_signature({"y": y}),
"regress": exporter.regression_signature(x, y)
})
export.export(export_path, tf.constant(123), sess)
def main(_):
Export()
if __name__ == "__main__":
tf.app.run()
保存好的模型看起来很像ckpt,但是再checkpoint里面可以看到,是“export”。 “00000123”这个文件名是自动生成的,我也不知道为什么会刚好是这个数字。
保存好的模型
1.2 保存pb模型
https://www.jianshu.com/p/9221fbf52c55 通过这个教程,我们把模型保存为pb格式。同样把这个模型文件夹挂在到docker相应的目录下。
保存为pb模型
2. 通过docker起服务
要指定端口,挂载目录,docker才能访问这个模型,挂在的目录得是绝对路径。
- export之后的模型挂载。
docker run -t --rm -p 8501:8501 \
-v "$(pwd)/model/half_plus_ten:/models/half_plus_ten" \
-e MODEL_NAME=half_plus_ten \
tensorflow/serving
- pb模型需要修改挂载路径,可以重新给模型起名字,这里还是用上面的名字“half_plus_ten"。
docker run -t --rm -p 8501:8501 \
-v "$(pwd)/pb_model:/models/half_plus_ten" \
-e MODEL_NAME=half_plus_ten \
tensorflow/serving
3. 测试服务
给它几个值来测试下这个服务。
curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_ten:predict
能得到half plus ten这个结果!
输出正确
用python代码访问服务
import os
import requests
from time import time
import numpy as np
url = 'http://localhost:8501/v1/models/half_plus_ten:predict'
a = np.array([1,2 ,3,4])
predict_request = '{"instances" : [{"input": %s}]}' % list(a) # 一定要list才能传输,不然json错误
print("start")
start_time = time()
r = requests.post(url,data=predict_request)
print(r.content)
end_time = time()
Tips:
代码改写自官方例子:https://github.com/tensorflow/serving/blob/master/tensorflow_serving/servables/tensorflow/testdata/export_half_plus_two.py
代码和模型都放在:
https://github.com/xxlxx1/learing_tf_serving