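This article shows how to deploy TensorFlow Serving with Docker and provides Python and Java client examples. It draws on many blog posts and the official TensorFlow documentation. The aim is to cover the pitfalls most posts leave out and to condense the more verbose official docs.
To avoid the major pitfall of compiling from source with bazel, this article deploys TensorFlow Serving directly with Docker. A bazel build tends to fail with obscure errors, mostly because the versions of the various dependencies do not match.
Note: if you follow the steps one by one, you can go from zero to a working deployment. The final part walks through a use case: a text classification model.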
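1 Installing Docker
1.1 Installing on Mac
See the reference website for installation instructions.
Manual installation is recommended. After installing, choose "Check for Updates" to update to the latest version.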
1.2 Installing on CentOS
Prerequisites: on CentOS 7, the system must be 64-bit and the kernel version must be 3.10 or higher. Check your kernel version with uname -r.
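You can also check both prerequisites from a script. The sketch below is a minimal, hypothetical helper; meets_docker_prereqs is not part of any library. It parses the kernel release that uname -r reports and the architecture that uname -m reports. Accepting x86_64 and aarch64 as the 64-bit values is an assumption.

```python
import platform
import re


def meets_docker_prereqs(release: str, machine: str, min_kernel=(3, 10)) -> bool:
    """Return True if the kernel release is >= min_kernel and the machine is 64-bit.

    release: text printed by `uname -r`, e.g. "3.10.0-1160.el7.x86_64"
    machine: text printed by `uname -m`, e.g. "x86_64"
    """
    match = re.match(r"(\d+)\.(\d+)", release)
    if match is None:
        return False
    kernel = (int(match.group(1)), int(match.group(2)))
    # Assumption: x86_64 and aarch64 are the 64-bit architectures we accept.
    is_64bit = machine in ("x86_64", "aarch64")
    return kernel >= min_kernel and is_64bit


if __name__ == "__main__":
    # On Linux, platform.release() and platform.machine() mirror `uname -r` and `uname -m`.
    print(meets_docker_prereqs(platform.release(), platform.machine()))
```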
Remove old versions:
```shell
sudo yum remove docker \
                docker-client \
                docker-client-latest \
                docker-common \
                docker-latest \
                docker-latest-logrotate \
                docker-logrotate \
                docker-selinux \
                docker-engine-selinux \
                docker-engine
```
Install the dependencies:
```shell
sudo yum install -y yum-utils device-mapper-persistent-data lvm2
```
Add the repository (Aliyun mirror):
```shell
sudo yum-config-manager --add-repo http://mirrors.aliyun.com/docker-ce/linux/centos/docker-ce.repo
```
Refresh the yum cache:
```shell
sudo yum makecache fast
```
Install Docker CE:
```shell
sudo yum -y install docker-ce
```
Start the Docker daemon:
```shell
sudo systemctl start docker
```
Test the installation by running hello-world, or just check the version:
```shell
docker run hello-world
docker --version
```
2 Deploying Serving
2.1 Pull the Serving image
```shell
docker pull tensorflow/serving
```
When the pull finishes, list the installed images:
```shell
docker images
```
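2.2 Export the model
Serving cannot load models saved as HDF5 or as .ckpt checkpoints directly. They must first be converted to the SavedModel format. This article uses an HDF5 file saved by Keras as the example. Converting a .ckpt checkpoint works in much the same way, and you can look it up yourself. The export code uses TensorFlow 1.x APIs (tf.saved_model.builder and K.get_session()).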
```python
import os

import tensorflow as tf  # TensorFlow 1.x APIs are used below
from keras import backend as K
from keras.models import load_model


def save_model_to_serving(model, export_version, export_path='prod_models'):
    """Export a Keras model as a SavedModel under export_path/export_version."""
    print(model.input, model.output)
    # 'textdata' and 'market' are the input/output keys the clients must use
    signature = tf.saved_model.signature_def_utils.predict_signature_def(
        inputs={'textdata': model.input}, outputs={'market': model.output})
    export_path = os.path.join(
        tf.compat.as_bytes(export_path),
        tf.compat.as_bytes(str(export_version)))
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    builder.add_meta_graph_and_variables(
        sess=K.get_session(),
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'market_classification': signature,  # signature name used by the clients
        },
        legacy_init_op=legacy_init_op)
    builder.save()


model = load_model('/your/path/blistm-checkpoint-02e-val_acc_0.96.hdf5')
# The converted model is written to bgru_serving/1/
save_model_to_serving(model, "1", "bgru_serving")
```
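After the conversion, bgru_serving/1/ contains saved_model.pb and a variables/ directory, which holds variables.data-00000-of-00001 and variables.index.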
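TensorFlow Serving treats each all-digit subdirectory of the model base path as a model version. Under its default version policy, it serves the highest one. The helpers below are hypothetical (list_servable_versions and latest_version are illustrative names). They sketch that rule so you can sanity-check an export before you mount it. This is an approximation, not Serving's own loader, and it only looks for saved_model.pb.

```python
import os


def list_servable_versions(model_base_path: str) -> list:
    """Return the all-digit version subdirectories that contain saved_model.pb, sorted ascending."""
    versions = []
    for name in os.listdir(model_base_path):
        version_dir = os.path.join(model_base_path, name)
        # Assumption: only a binary saved_model.pb counts as a valid export here.
        if name.isdigit() and os.path.isfile(os.path.join(version_dir, "saved_model.pb")):
            versions.append(int(name))
    return sorted(versions)


def latest_version(model_base_path: str):
    """Return the highest version number, or None when no valid version exists."""
    versions = list_servable_versions(model_base_path)
    return versions[-1] if versions else None


if __name__ == "__main__":
    # "bgru_serving" is the export directory used in this article.
    if os.path.isdir("bgru_serving"):
        print(latest_version("bgru_serving"))
```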
2.3 Run the container
```shell
docker run -p 8500:8500 \
  --mount type=bind,source=/your/absolute/path/bgru_serving/,target=/models/market_blstm \
  -e MODEL_NAME=market_blstm -t tensorflow/serving
```
Note: port 8500 (gRPC) is recommended for testing. Replace /your/absolute/path with an absolute path.
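Before running a client, it can help to confirm that something is listening on the published port. This hypothetical helper (port_is_open) only attempts a TCP connection. A successful connection does not prove that the model has finished loading.

```python
import socket


def port_is_open(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    # 127.0.0.1:8500 matches the -p 8500:8500 mapping in the docker run command.
    print(port_is_open("127.0.0.1", 8500))
```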
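Meaning of each parameter (important):
- -p 8500:8500: publishes the container's gRPC port 8500 on port 8500 of the host.
- --mount type=bind,source=/your/absolute/path/bgru_serving/,target=/models/market_blstm: bind-mounts your exported local model folder at /models/market_blstm inside the container. TensorFlow Serving looks for the model in that folder.
- -e MODEL_NAME=market_blstm: sets the MODEL_NAME environment variable. The image serves the model found at /models/${MODEL_NAME}, so this name must match the mount target, and clients must send the same name as the model name in their requests.
- -t tensorflow/serving: -t only allocates a pseudo-terminal. The final argument, tensorflow/serving, is the image to run. You can substitute another tag, such as tensorflow/serving:latest-gpu, but you must pull that image first with docker pull tensorflow/serving:latest-gpu.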
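The parameters above depend on each other: the mount target must be /models/ followed by MODEL_NAME, and the --mount source must be an absolute path. Below is a minimal sketch that assembles the same docker run arguments from those values. build_serving_command and the path /data/bgru_serving are hypothetical.

```python
import os
import shlex


def build_serving_command(model_dir: str, model_name: str, grpc_port: int = 8500,
                          image: str = "tensorflow/serving") -> list:
    """Build the `docker run` argument list for TensorFlow Serving.

    model_dir: host folder holding the numeric version folders (e.g. bgru_serving/).
    """
    if not os.path.isabs(model_dir):
        raise ValueError("--mount source must be an absolute path: %r" % model_dir)
    # The image serves /models/${MODEL_NAME}, so the mount target reuses the model name.
    target = "/models/" + model_name
    return [
        "docker", "run",
        "-p", "%d:8500" % grpc_port,  # host port -> container gRPC port 8500
        "--mount", "type=bind,source=%s,target=%s" % (model_dir, target),
        "-e", "MODEL_NAME=" + model_name,
        "-t", image,
    ]


if __name__ == "__main__":
    cmd = build_serving_command("/data/bgru_serving", "market_blstm")  # hypothetical path
    # Print a copy-pasteable command; subprocess.run(cmd) would launch it instead.
    print(" ".join(shlex.quote(arg) for arg in cmd))
```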
3 Client
3.1 Python example
Note: Python 3.5+ is recommended. With older Python versions, recent TensorFlow releases raise errors.
Install the dependency: sudo pip3 install tensorflow-serving-api
Client code:
```python
from __future__ import print_function

import codecs
import json
import re

import grpc
import jieba
import numpy as np
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc


def loadData(filename):
    """Load a JSON file into a dict."""
    with codecs.open(filename, 'r', 'utf-8') as fr:
        resdict = json.load(fr)
    return resdict


vocab = loadData('vocab_bgru.dict')  # word -> index dictionary, format: "中国":12045


def denoise(text):
    """Clean and tokenize one text, then map each word to its embedding index.

    Returns a one-element list that holds the list of word ids.
    """
    x_train_word_ids = []
    tem = []
    patten = re.compile(r'(https|http)?:\/\/(\w|\.|\/|\?|\=|\&|\%)*\b', re.S)
    line = text.strip()
    line = patten.sub('', line)  # text is already str in Python 3, so no decode() is needed
    line = re.sub(r'{url(.*)网页链接}', '', line)
    line = line.replace('\\', '').replace('\n', ' ').replace('https://', ' ')
    wordlist = [emt.strip() for emt in jieba.cut(line) if len(emt.strip()) >= 2]
    for word in wordlist:
        code = vocab.get(word)
        if code is None:  # skip out-of-vocabulary words
            continue
        tem.append(code)
    x_train_word_ids.append(tem)
    return x_train_word_ids


def pad_sequences(x_train_word_ids, maxlen=64):
    """Truncate or left-pad one text's word ids to exactly maxlen entries for the LSTM input."""
    ids = x_train_word_ids[0]
    len_x = len(ids)
    if len_x > maxlen:
        return ids[len_x - maxlen:]  # keep the last maxlen ids
    res = [0] * maxlen
    for i, emt in enumerate(ids):
        res[maxlen - len_x + i] = emt  # zeros go at the front
    return res


SERVER = '127.0.0.1:8500'  # host:port; replace the IP with that of the server to connect to

channel = grpc.insecure_channel(SERVER)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = 'market_blstm'  # must match MODEL_NAME in docker run
request.model_spec.signature_name = 'market_classification'  # signature name from section 2.2
text_list = ['吴亦凡同款 Sup扎染卫衣 全身顶级数码直喷 印花带做旧感 就是看起来脏脏的 一件衣服印花大几十块 完美还原面料为420G毛圈轻捉毛 质感很好', "360儿童5周年不止5折# 360儿童手表五周年&双十一特惠! 喜欢![失望]"]
x_train = np.array([pad_sequences(denoise(text)) for text in text_list], dtype=np.float32)
# The shape must be [number of texts, 64] to match the Keras model.input;
# 'textdata' is the input key from section 2.2
request.inputs['textdata'].CopyFrom(
    tf.make_tensor_proto(x_train, shape=list(x_train.shape), dtype=tf.float32))
result = stub.Predict(request, timeout=10.0)
reslist = result.outputs['market'].float_val
print(reslist)
```
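The tensor shape passed to make_tensor_proto must be [number of texts, 64], and the values are stored row by row (row-major). The Java client uses the same order when it calls addFloatVal. The sketch below reproduces the client's padding rule and that flattening without TensorFlow. pad_pre and to_flat_tensor are hypothetical names. The assumption being illustrated is that the padding rule matches Keras' default padding='pre', truncating='pre'.

```python
def pad_pre(ids, maxlen=64):
    """Keep the last maxlen ids and left-pad with zeros (pre-truncation, pre-padding)."""
    ids = list(ids)[-maxlen:]
    return [0] * (maxlen - len(ids)) + ids


def to_flat_tensor(batch_ids, maxlen=64):
    """Return (shape, values): the values are flattened row-major, as in TensorProto.float_val."""
    rows = [pad_pre(ids, maxlen) for ids in batch_ids]
    flat = [float(v) for row in rows for v in row]
    shape = [len(rows), maxlen]
    # The element count must equal the product of the shape dimensions.
    assert len(flat) == shape[0] * shape[1]
    return shape, flat


if __name__ == "__main__":
    shape, flat = to_flat_tensor([[12045, 7], [3]], maxlen=4)
    print(shape, flat)
```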
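The output is:
[0.013646061532199383, 0.9863539338111877, 0.16853764653205872, 0.8314623832702637]
Every two values form the prediction for one text, in the order of text_list. For example, the pair 0.013646061532199383, 0.9863539338111877 means that the first text in text_list belongs to class 0 with probability 0.013646061532199383 and to class 1 with probability 0.9863539338111877.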
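Because float_val is a flat list, splitting it into pairs gives one row of class probabilities per text. Below is a minimal sketch that assumes a two-class output, as in this example. group_predictions and predicted_labels are hypothetical helpers.

```python
def group_predictions(flat_scores, num_classes=2):
    """Split the flat output into one row of class probabilities per input text."""
    if len(flat_scores) % num_classes != 0:
        raise ValueError("output length is not a multiple of num_classes")
    return [list(flat_scores[i:i + num_classes])
            for i in range(0, len(flat_scores), num_classes)]


def predicted_labels(rows):
    """Return the index of the highest probability in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


if __name__ == "__main__":
    reslist = [0.013646061532199383, 0.9863539338111877,
               0.16853764653205872, 0.8314623832702637]
    rows = group_predictions(reslist)
    print(predicted_labels(rows))  # [1, 1]: both texts are predicted as class 1
```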
3.2 Java example
Dependencies in pom.xml:
```xml
<dependencies>
    <dependency>
        <groupId>com.yesup.oss</groupId>
        <artifactId>tensorflow-client</artifactId>
        <version>1.4-2</version>
    </dependency>
    <dependency>
        <groupId>io.grpc</groupId>
        <artifactId>grpc-netty</artifactId>
        <version>1.7.0</version>
    </dependency>
    <dependency>
        <groupId>io.netty</groupId>
        <artifactId>netty-tcnative-boringssl-static</artifactId>
        <version>2.0.7.Final</version>
    </dependency>
    <dependency>
        <groupId>com.huaban</groupId>
        <artifactId>jieba-analysis</artifactId>
        <version>1.0.2</version>
    </dependency>
    <dependency>
        <groupId>net.sf.json-lib</groupId>
        <artifactId>json-lib</artifactId>
        <version>2.4</version>
        <classifier>jdk15</classifier>
    </dependency>
    <dependency>
        <groupId>commons-io</groupId>
        <artifactId>commons-io</artifactId>
        <version>2.6</version>
    </dependency>
</dependencies>
```
Code:
```java
import com.huaban.analysis.jieba.JiebaSegmenter;
import com.huaban.analysis.jieba.WordDictionary;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import net.sf.json.JSONObject;
import org.apache.commons.io.FileUtils;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;
import tensorflow.serving.Model;
import tensorflow.serving.Predict;
import tensorflow.serving.PredictionServiceGrpc;

import java.io.File;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class TensorServClient {
    private PredictionServiceGrpc.PredictionServiceBlockingStub stub;
    private JiebaSegmenter segmenter;
    private JSONObject json;
    private static final int maxlen = 64;  // padded sequence length
    private static final int batch = 200;  // number of rows in each request tensor

    public TensorServClient() {
        ManagedChannel channel = ManagedChannelBuilder.forAddress("127.0.0.1", 8500)
                .usePlaintext(true)
                .build();
        // use a blocking stub for simplicity
        stub = PredictionServiceGrpc.newBlockingStub(channel);
        WordDictionary dictAdd = WordDictionary.getInstance();
        dictAdd.loadUserDict(Paths.get("jiebaextradic_java.dict"));  // load the custom jieba dictionary
        segmenter = new JiebaSegmenter();
        try {
            // word -> index dictionary, format: "中国":12045
            json = JSONObject.fromObject(FileUtils.readFileToString(new File("vocab_bgru.dict"), "UTF-8"));
        } catch (Exception ex) {
            ex.printStackTrace();
        }
    }

    // Clean and tokenize one text, then map each word to its index in the vocabulary
    private ArrayList<Integer> denoise(String line) {
        ArrayList<Integer> x_train_word_ids = new ArrayList<Integer>();
        line = line.replaceAll("(http|ftp|https):\\/\\/[\\w\\-_]+(\\.[\\w\\-_]+)+([\\w\\-\\.,@?^=%&:/~\\+#]*[\\w\\-\\@?^=%&/~\\+#])?", "");
        line = line.replaceAll("\\{url(.*)网页链接\\}", "");
        line = line.replaceAll("\\\\", "").replaceAll("\\r|\\n", "").replaceAll("https://", "");
        List<String> wordjiebaList = segmenter.sentenceProcess(line);
        for (String word : wordjiebaList) {
            word = word.trim();
            // like the Python client, skip single-character and out-of-vocabulary words
            if (word.length() < 2 || !this.json.containsKey(word)) {
                continue;
            }
            try {
                x_train_word_ids.add(this.json.getInt(word));
            } catch (Exception e) {
                x_train_word_ids.add(0);
            }
        }
        return x_train_word_ids;
    }

    // Keep the last maxlen ids, or left-pad with zeros up to maxlen
    private float[] padSequences(ArrayList<Integer> x_train_word_ids) {
        float[] res = new float[maxlen];
        int len_x = x_train_word_ids.size();
        if (len_x > maxlen) {
            for (int i = len_x - maxlen, j = 0; i < len_x; i++, j++) {
                res[j] = x_train_word_ids.get(i);
            }
        } else {
            for (int i = 0; i < len_x; i++) {
                res[maxlen - len_x + i] = x_train_word_ids.get(i);
            }
        }
        return res;
    }

    // Build a [batch][maxlen] matrix; rows beyond textlist.length stay all zeros
    private float[][] gen_predict_data(String[] textlist) {
        if (textlist.length > batch) {
            throw new IllegalArgumentException("At most " + batch + " texts per request");
        }
        float[][] predict_data = new float[batch][maxlen];
        for (int i = 0; i < textlist.length; i++) {
            predict_data[i] = padSequences(denoise(textlist[i]));
        }
        return predict_data;
    }

    public List<Float> predict(String[] textlist) {
        // create the request
        Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
        // model name (must match MODEL_NAME in docker run) and signature name (from section 2.2)
        Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName("market_blstm");
        modelSpecBuilder.setSignatureName("market_classification");
        // The latest model version is used by default; to pin a version, set the version
        // field of ModelSpec (TensorProto.setVersionNumber is unrelated to model versions).
        predictRequestBuilder.setModelSpec(modelSpecBuilder);
        // input tensor: float32, shape [batch, maxlen], values added in row-major order
        TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
        tensorProtoBuilder.setDtype(DataType.DT_FLOAT);
        TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(batch));
        tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(maxlen));
        tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
        float[][] featuresTensorData = gen_predict_data(textlist);
        for (int i = 0; i < featuresTensorData.length; ++i) {
            for (int j = 0; j < featuresTensorData[i].length; ++j) {
                tensorProtoBuilder.addFloatVal(featuresTensorData[i][j]);
            }
        }
        predictRequestBuilder.putInputs("textdata", tensorProtoBuilder.build());
        // call the server and read the "market" output
        Predict.PredictResponse predictResponse = stub.predict(predictRequestBuilder.build());
        TensorProto result = predictResponse.getOutputsOrThrow("market");
        return result.getFloatValList();
    }

    public static void main(String[] args) throws Exception {
        TensorServClient tensorServClient = new TensorServClient();
        String[] textlist = {"吴亦凡同款 Sup扎染卫衣 全身顶级数码直喷 印花带做旧感 就是看起来脏脏的 一件衣服印花大几十块 完美还原面料为420G毛圈轻捉毛 质感很好","360儿童5周年不止5折# 360儿童手表五周年&双十一特惠! 喜欢![失望]",....};  // this array has length batch, for batch processing
        System.out.println(tensorServClient.predict(textlist));
    }
}
```
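Note: in the Java example, textlist has length batch, and each element is one text. The results have the same format as in the Python example: two values per text.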
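The Java client always sends a [batch, maxlen] tensor and leaves unused rows at zero. A longer list of texts therefore has to be split into chunks of at most batch texts, and you should ignore the scores returned for the zero-filled rows. The batching logic is sketched below in Python for brevity. chunk and real_scores are hypothetical helpers; batch = 200 and the two classes are taken from the example.

```python
def chunk(texts, batch_size=200):
    """Yield consecutive slices of at most batch_size texts."""
    for start in range(0, len(texts), batch_size):
        yield texts[start:start + batch_size]


def real_scores(flat_scores, num_texts, num_classes=2):
    """Keep only the scores of the first num_texts rows and drop those of zero-filled rows."""
    return list(flat_scores[:num_texts * num_classes])


if __name__ == "__main__":
    texts = ["text %d" % i for i in range(450)]
    print([len(part) for part in chunk(texts)])
```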