References: Tensorflow 模型线上部署 (Deploying TensorFlow models online)
构建 TensorFlow Serving Java 客户端 (Building a TensorFlow Serving Java client)
-
Docker installation and deployment
-
Installing Docker on Windows
-
tf-serving
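Pull the TensorFlow Serving image and deploy it with Docker. If this step takes up too much space on the C: drive, you can use the Hyper-V tools to move the downloaded image to another drive.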
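    # Run the following commands in cmd

    # download the image
    docker pull tensorflow/serving
    # start a container from the image and name it tfserving
    docker run -itd -p 5000:5000 --name tfserving tensorflow/serving
    # look up the container ID (written as dockerID below)
    docker ps
    # copy the pb folder into the container; model training is covered below
    docker cp ./mnist dockerID:/models
    # open a shell inside the container
    docker exec -it dockerID /bin/bash
    # start the server from inside the container
    tensorflow_model_server --port=5000 --model_name=mnist --model_base_path=/models/mnist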
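Before writing any client code, it can save time to confirm that the published port accepts connections at all. Below is a minimal sketch using only the Python standard library; the helper name port_is_open is my own, and the host and port are taken from the docker run command above. Note that a successful connection only shows that the port mapping answers, not that the mnist model has finished loading.

```python
import socket


def port_is_open(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts and unreachable hosts.
        return False


if __name__ == "__main__":
    # 5000 is the port published with `docker run -p 5000:5000`.
    print("serving port reachable:", port_is_open("localhost", 5000))
```

If this prints False, the container is probably not running or the port mapping is missing, so there is little point in debugging the gRPC client yet.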
-
-
Training the model
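Train with the official MNIST example; only the paths in the code need changing. Training produces the pb files. Run
    saved_model_cli show --dir ./mnist/1 --all
to list the signature and node names (the client needs them later), then copy the model into the container with docker cp ./mnist dockerID:/models. Pay attention to the folder nesting here: the model base path has to contain a numbered version folder (here mnist/1).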
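Since the folder nesting is easy to get wrong, here is a small sketch that checks a model folder before it is copied into the container. It assumes the usual TensorFlow Serving layout, where the model base path holds numbered version folders and each one contains saved_model.pb. The function name find_model_versions is hypothetical, and the check ignores the less common saved_model.pbtxt variant.

```python
from pathlib import Path
from typing import List


def find_model_versions(model_base_path: str) -> List[int]:
    """Return the numeric version folders under model_base_path that contain saved_model.pb."""
    base = Path(model_base_path)
    versions: List[int] = []
    if not base.is_dir():
        return versions
    for child in base.iterdir():
        # A servable version looks like <base>/<number>/saved_model.pb
        if child.is_dir() and child.name.isdigit() and (child / "saved_model.pb").is_file():
            versions.append(int(child.name))
    return sorted(versions)


if __name__ == "__main__":
    found = find_model_versions("./mnist")
    if found:
        print("servable versions:", found)
    else:
        print("no version folder with saved_model.pb found; check the folder nesting")
```

An empty result usually means either that saved_model.pb sits directly in ./mnist without a version folder, or that there is one folder level too many.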
-
Python client
I wrote the prediction code modeled on the official mnist_client.py.
    import grpc
    import tensorflow as tf
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2_grpc

    # helper module that ships with the official TensorFlow Serving MNIST example
    import mnist_input_data

    server = 'localhost:5000'
    channel = grpc.insecure_channel(server)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    # model name and signature must match what the server was started with
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'mnist'
    request.model_spec.signature_name = 'predict_images'

    # take a single test image and send it as a 1 x 784 tensor
    test_data_set = mnist_input_data.read_data_sets('./data').test
    image, label = test_data_set.next_batch(1)
    request.inputs['images'].CopyFrom(tf.make_tensor_proto(image[0], shape=[1, image[0].size]))

    pred = stub.Predict(request, 5.0)  # 5-second timeout
    score = pred.outputs['scores'].float_val
    print(score)
    # [1.6178478001727115e-10, 1.6928293322847278e-15, 1.6151154341059737e-05, 0.000658366538118571, 8.010060947860609e-10, 2.2359495588375466e-08, 3.5608297452131843e-13, 0.9993133544921875, 5.620326870570125e-09, 1.1990837265329901e-05]
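The raw score list is hard to read at a glance. The sketch below turns it into a single predicted digit by taking the index of the largest score. It assumes that the scores output holds ten values ordered by digit 0 to 9, which is how I understand the official example; the helper name predicted_digit is my own.

```python
def predicted_digit(scores):
    """Return the index of the largest score, i.e. the most likely digit."""
    return max(range(len(scores)), key=lambda i: scores[i])


# The score list printed by the client above.
scores = [
    1.6178478001727115e-10, 1.6928293322847278e-15, 1.6151154341059737e-05,
    0.000658366538118571, 8.010060947860609e-10, 2.2359495588375466e-08,
    3.5608297452131843e-13, 0.9993133544921875, 5.620326870570125e-09,
    1.1990837265329901e-05,
]

print(predicted_digit(scores))  # 7
```

For this response the model is almost certain about one class (score 0.9993 at index 7), which can then be compared with the label returned by next_batch.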
-
Java client
The Java flow is much the same; the main hassle is compiling the proto files.
-
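Installing protoc
For installing protoc on Windows, see "windows之google protobuf安装与使用" (installing and using Google protobuf on Windows). Download protoc-3.4.0 and unzip it. Make sure the directory path contains no spaces, otherwise compilation will fail later. Then locate the folder that contains protoc.exe; mine is D:\protoc-3.4.0-win32\bin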
-
Configuring the pom to compile the proto files
This part mainly follows "构建 TensorFlow Serving Java 客户端"; the list of proto files it gives is great (I have not worked out why it is exactly these files, as I am not familiar with java-grpc). Following its steps, download the tensorflow and tensorflow-serving projects and copy out the corresponding proto files:

    src/main/proto
    ├── tensorflow
    │   └── core
    │       ├── example
    │       │   ├── example.proto
    │       │   └── feature.proto
    │       ├── framework
    │       │   ├── attr_value.proto
    │       │   ├── function.proto
    │       │   ├── graph.proto
    │       │   ├── node_def.proto
    │       │   ├── op_def.proto
    │       │   ├── resource_handle.proto
    │       │   ├── tensor.proto
    │       │   ├── tensor_shape.proto
    │       │   ├── types.proto
    │       │   └── versions.proto
    │       └── protobuf
    │           ├── meta_graph.proto
    │           └── saver.proto
    └── tensorflow_serving
        └── apis
            ├── classification.proto
            ├── get_model_metadata.proto
            ├── inference.proto
            ├── input.proto
            ├── model.proto
            ├── predict.proto
            ├── prediction_service.proto
            └── regression.proto
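Copying these files by hand is tedious, so here is a rough sketch of a script that does it. The relative paths are the ones listed above; the checkout locations ../tensorflow and ../serving and the function name copy_protos are assumptions of mine, so adjust them to wherever the two projects were downloaded. Each file keeps its relative path, because the proto files import each other by paths such as tensorflow/core/framework/tensor.proto.

```python
import shutil
from pathlib import Path

TENSORFLOW_PROTOS = [
    "tensorflow/core/example/example.proto",
    "tensorflow/core/example/feature.proto",
    "tensorflow/core/framework/attr_value.proto",
    "tensorflow/core/framework/function.proto",
    "tensorflow/core/framework/graph.proto",
    "tensorflow/core/framework/node_def.proto",
    "tensorflow/core/framework/op_def.proto",
    "tensorflow/core/framework/resource_handle.proto",
    "tensorflow/core/framework/tensor.proto",
    "tensorflow/core/framework/tensor_shape.proto",
    "tensorflow/core/framework/types.proto",
    "tensorflow/core/framework/versions.proto",
    "tensorflow/core/protobuf/meta_graph.proto",
    "tensorflow/core/protobuf/saver.proto",
]

SERVING_PROTOS = [
    "tensorflow_serving/apis/classification.proto",
    "tensorflow_serving/apis/get_model_metadata.proto",
    "tensorflow_serving/apis/inference.proto",
    "tensorflow_serving/apis/input.proto",
    "tensorflow_serving/apis/model.proto",
    "tensorflow_serving/apis/predict.proto",
    "tensorflow_serving/apis/prediction_service.proto",
    "tensorflow_serving/apis/regression.proto",
]


def copy_protos(repo_root, relative_paths, dest_root):
    """Copy each proto file, keeping its relative path; return the paths that were not found."""
    missing = []
    for rel in relative_paths:
        src = Path(repo_root) / rel
        dst = Path(dest_root) / rel
        if not src.is_file():
            missing.append(rel)
            continue
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(src, dst)
    return missing


if __name__ == "__main__":
    dest = "src/main/proto"
    # Hypothetical checkout locations of the two projects.
    missing = copy_protos("../tensorflow", TENSORFLOW_PROTOS, dest)
    missing += copy_protos("../serving", SERVING_PROTOS, dest)
    print("missing files:", missing)
```

Any entry reported as missing may simply have moved in a newer version of the repositories, in which case checking out an older tag is one option.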
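Create a Maven project and put the proto files above under src/main. Add the following to the pom. I additionally set the input and output directories for compilation, because without them the build fails with "protoc did not exit cleanly".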
    <build>
        <plugins>
            <plugin>
                <groupId>org.xolstice.maven.plugins</groupId>
                <artifactId>protobuf-maven-plugin</artifactId>
                <version>0.5.0</version>
                <configuration>
                    <protocExecutable>D:\protoc-3.4.0-win32\bin\protoc.exe</protocExecutable>
                    <protoSourceRoot>${project.basedir}/src/main/proto/</protoSourceRoot>
                    <outputDirectory>${project.basedir}/src/main/resources/</outputDirectory>
                </configuration>
                <executions>
                    <execution>
                        <goals>
                            <goal>compile</goal>
                            <goal>compile-custom</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    <dependencies>
        <dependency>
            <groupId>com.google.protobuf</groupId>
            <artifactId>protobuf-java</artifactId>
            <version>3.11.4</version>
        </dependency>
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-protobuf</artifactId>
            <version>1.28.0</version>
        </dependency>
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-stub</artifactId>
            <version>1.28.0</version>
        </dependency>
        <dependency>
            <groupId>io.grpc</groupId>
            <artifactId>grpc-netty-shaded</artifactId>
            <version>1.28.0</version>
        </dependency>
    </dependencies>
Once configured, run
    maven -> protobuf:compile
to compile. Two folders, org and tensorflow, are generated under the resources directory; copy both of them into src/main/java.
-
Prediction
While writing the Java prediction program I found that tensorflow/serving/PredictionServiceGrpc.java was missing. I tried many approaches and could not get it generated, so in the end I copied someone else's PredictionServiceGrpc over. After copying it in, the lines with @java.lang.Override were flagged as errors, so I simply commented those annotations out. Then create the package and class under src/main/java and write the prediction code. The complete code is below; running it gives the prediction result.

    package SimpleAdd;

    import io.grpc.ManagedChannel;
    import io.grpc.ManagedChannelBuilder;
    import org.tensorflow.framework.DataType;
    import org.tensorflow.framework.TensorProto;
    import org.tensorflow.framework.TensorShapeProto;
    import tensorflow.serving.Model;
    import tensorflow.serving.Predict;
    import tensorflow.serving.PredictionServiceGrpc;

    public class MnistPredict {
        public static void main(String[] args) throws Exception {
            // create a channel for gRPC
            ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 5000).usePlaintext().build();
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);

            // create a modelspec
            Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder();
            modelSpec.setName("mnist");
            modelSpec.setSignatureName("predict_images");
            Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
            request.setModelSpec(modelSpec);

            // data shape & load data: a 1 x 784 float tensor, filled with zeros here
            TensorShapeProto.Builder shape = TensorShapeProto.newBuilder();
            shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
            shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(784));
            TensorProto.Builder tensor = TensorProto.newBuilder();
            tensor.setTensorShape(shape);
            tensor.setDtype(DataType.DT_FLOAT);
            for (int i = 0; i < 784; i++) {
                tensor.addFloatVal(0);
            }
            request.putInputs("images", tensor.build());
            tensor.clear();

            // Predict
            Predict.PredictResponse response = stub.predict(request.build());
            System.out.println(response);
            TensorProto result = response.toBuilder().getOutputsOrThrow("scores");
            System.out.println("predict: " + result.getFloatValList());
            System.out.println("predict: " + response.getOutputsMap().get("scores").getFloatValList());
            // predict: [0.032191742, 0.09621494, 0.06525445, 0.039610844, 0.05699038, 0.46822935, 0.040578533, 0.1338098, 0.009549928, 0.057570033]

            // release the channel once the request is done
            channel.shutdown();
        }
    }
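The Java program above sends 784 zeros, which is enough to prove the round trip but not a real image. To send a real digit, the 28x28 pixel grid has to be flattened into the same 1 x 784 layout. The sketch below, written in Python for brevity, shows one way to do that. It assumes row-by-row flattening and scaling from 0-255 to [0, 1], which is what I believe the official mnist_input_data reader does; the function name flatten_image is hypothetical. The resulting values would be passed one by one to tensor.addFloatVal in the Java loop.

```python
def flatten_image(pixels):
    """Flatten a 28x28 grid of 0-255 grayscale values into 784 floats in [0, 1], row by row."""
    if len(pixels) != 28 or any(len(row) != 28 for row in pixels):
        raise ValueError("expected a 28x28 image")
    return [value / 255.0 for row in pixels for value in row]


if __name__ == "__main__":
    # A blank image with one fully white pixel in row 1, column 2.
    image = [[0] * 28 for _ in range(28)]
    image[1][2] = 255
    flat = flatten_image(image)
    print(len(flat), flat[1 * 28 + 2])
```

If the model was trained on differently scaled inputs, the division by 255 would need to change accordingly.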
-