Tensorflow模型部署

参考:  Tensorflow 模型线上部署
    构建 TensorFlow Serving Java 客户端

  • docker安装及部署

    • windows下docker安装

    • tf-serving

        下载tensorflow服务并使用docker部署,这一步如果占用C盘空间太大的话,可以使用Hyper-v工具将下载的镜像转到其他盘

      # 在 cmd 中执行以下命令
      docker pull tensorflow/serving   # 下载镜像
      docker run -itd -p 5000:5000 --name tfserving tensorflow/serving   # 运行镜像并指定镜像名
      docker ps  # 查看镜像id  dockerID
      docker cp ./mnist dockerID:/models  # 将 pb 文件夹拷贝到容器中,模型训练见下面
      
      docker exec -it dockerID /bin/bash  # 进入到镜像里面
      tensorflow_model_server --port=5000 --model_name=mnist --model_base_path=/models/mnist  # 容器内运行服务
      
  • 训练模型

      使用官方给出的mnist样例进行训练,改下代码路径就可以,训练得到pb文件如下,并使用 saved_model_cli show --dir ./mnist/1 --all 命令查看节点名称(后面客户端使用),并将模型复制到docker里面docker cp ./mnist dockerID:/models,此处注意文件夹层级

    mnist-pb.png

    sigdef.png
    models.png
  • python端

      仿照官方代码 mnist_clien.py编写预测代码

    import grpc
    import tensorflow as tf
    from tensorflow_serving.apis import predict_pb2
    from tensorflow_serving.apis import prediction_service_pb2_grpc
    
    server = 'localhost:5000'
    
    channel = grpc.insecure_channel(server)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'mnist'
    request.model_spec.signature_name = 'predict_images'
    
    test_data_set = mnist_input_data.read_data_sets('./data').test
    image, label = test_data_set.next_batch(1)
    request.inputs['images'].CopyFrom(tf.make_tensor_proto(image[0], shape=[1, image[0].size]))
    pred = stub.Predict(request, 5.0)
    score = pred.outputs['scores'].float_val
    print(score)
    # [1.6178478001727115e-10, 1.6928293322847278e-15, 1.6151154341059737e-05, 0.000658366538118571, 8.010060947860609e-10, 2.2359495588375466e-08, 3.5608297452131843e-13, 0.9993133544921875, 5.620326870570125e-09, 1.1990837265329901e-05]
    
  • Java端

      Java端流程差不多,主要是编译proto麻烦一些

    • proto安装

        windows下proto的安装参考windows之google protobuf安装与使用,下载proto-3.4.0并解压,注意目录不要有空格,否则后面编译会报错,找到protoc.exe所在路径,我的是D:\protoc-3.4.0-win32\bin

    • pom配置编译proto

        此处主要参考构建 TensorFlow Serving Java 客户端,给出的那个proto文件列表太棒了(未理解为什么是这些文件,对java-grpc不熟悉),仿照其流程,下载tensorflowtensorflow-serving两个项目,复制相应的proto文件出来

      src/main/proto
      ├── tensorflow
      │   └── core
      │       ├── example
      │       │   ├── example.proto
      │       │   └── feature.proto
      │       ├── framework
      │       │   ├── attr_value.proto
      │       │   ├── function.proto
      │       │   ├── graph.proto
      │       │   ├── node_def.proto
      │       │   ├── op_def.proto
      │       │   ├── resource_handle.proto
      │       │   ├── tensor.proto
      │       │   ├── tensor_shape.proto
      │       │   ├── types.proto
      │       │   └── versions.proto
      │       └── protobuf
      │           ├── meta_graph.proto
      │           └── saver.proto
      └── tensorflow_serving
          └── apis
              ├── classification.proto
              ├── get_model_metadata.proto
              ├── inference.proto
              ├── input.proto
              ├── model.proto
              ├── predict.proto
              ├── prediction_service.proto
              └── regression.proto
      

        创建Maven工程,将上面的proto文件放在src/main下面,在pom中添加以下信息,此处额外添加了编译文件的输入及输出目录,否则会报错 protoc did not exit cleanly

      <build>
          <plugins>
              <plugin>
                  <groupId>org.xolstice.maven.plugins</groupId>
                  <artifactId>protobuf-maven-plugin</artifactId>
                  <version>0.5.0</version>
                  <configuration>
                      <protocExecutable>D:\protoc-3.4.0-win32\bin\protoc.exe</protocExecutable>
                      <protoSourceRoot>${project.basedir}/src/main/proto/</protoSourceRoot>
                      <outputDirectory>${project.basedir}/src/main/resources/</outputDirectory>
                  </configuration>
                  <executions>
                      <execution>
                          <goals>
                              <goal>compile</goal>
                              <goal>compile-custom</goal>
                          </goals>
                      </execution>
                  </executions>
              </plugin>
          </plugins>
      </build>
      
      <dependencies>
          <dependency>
              <groupId>com.google.protobuf</groupId>
              <artifactId>protobuf-java</artifactId>
              <version>3.11.4</version>
          </dependency>
          <dependency>
              <groupId>io.grpc</groupId>
              <artifactId>grpc-protobuf</artifactId>
              <version>1.28.0</version>
          </dependency>
          <dependency>
              <groupId>io.grpc</groupId>
              <artifactId>grpc-stub</artifactId>
              <version>1.28.0</version>
          </dependency>
          <dependency>
              <groupId>io.grpc</groupId>
              <artifactId>grpc-netty-shaded</artifactId>
              <version>1.28.0</version>
          </dependency>
      </dependencies>
      

        配置完后,执行maven -> protobuf:compile编译,在resources目录下会生成org及tensorflow两个文件夹,将这两个文件夹复制到src/main/java目录下

      proto.png

    • 预测

        编写java程序进行预测,过程中发现没有tensorflow/serving/PredictionServiceGrpc.java这个文件,试了很多方法都没有编译出来,最后是直接把别人的给复制过来了,PredictionServiceGrpc,拷过来后发现报了@java.lang.Override这几行代码提示有问题,直接将override注释掉

      src/main/java下建表及类,编写预测代码,完整代码如下,运行得预测结果

      package SimpleAdd;
      
      import io.grpc.ManagedChannel;
      import io.grpc.ManagedChannelBuilder;
      import tensorflow.serving.Model;
      import org.tensorflow.framework.DataType;
      import org.tensorflow.framework.TensorProto;
      import org.tensorflow.framework.TensorShapeProto;
      
      import tensorflow.serving.Predict;
      import tensorflow.serving.PredictionServiceGrpc;
      
      
      public class MnistPredict {
          public static void main(String[] args) throws Exception {
              // create a channel for gRPC
              ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 5000).usePlaintext().build();
              PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
      
              // create a modelspec
              Model.ModelSpec.Builder modelSpec = Model.ModelSpec.newBuilder();
              modelSpec.setName("mnist");
              modelSpec.setSignatureName("predict_images");
              Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();
              request.setModelSpec(modelSpec);
      
              // data shape & load data
              TensorShapeProto.Builder shape = TensorShapeProto.newBuilder();
              shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
              shape.addDim(TensorShapeProto.Dim.newBuilder().setSize(784));
              TensorProto.Builder tensor = TensorProto.newBuilder();
              tensor.setTensorShape(shape);
              tensor.setDtype(DataType.DT_FLOAT);
              for(int i=0; i<784; i++){
                  tensor.addFloatVal(0);
              }
              request.putInputs("images", tensor.build());
              tensor.clear();
      
              // Predict 
              Predict.PredictResponse response = stub.predict(request.build());
              System.out.println(response);
              TensorProto result = response.toBuilder().getOutputsOrThrow("scores");
              System.out.println("predict: " + result.getFloatValList());
              System.out.println("predict: " + response.getOutputsMap().get("scores").getFloatValList());
              // predict: [0.032191742, 0.09621494, 0.06525445, 0.039610844, 0.05699038, 0.46822935, 0.040578533, 0.1338098, 0.009549928, 0.057570033]
          }
      }
      
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。