tf2.0 tfserving

2018年写过tf保存为pb使用tfserving,现在发现tf2.0环境运行不了了,于是重新研究下
官方例子也变了,使用tf.compat 兼容api实现
简化了官方版本,更清晰简洁,如下所示:

import tensorflow as tf

def export():
    export_path = "model/half_plus_ten/1"
    with tf.compat.v1.keras.backend.get_session() as sess:
        # 定义模型,参数、输入输入
        a = tf.Variable(100.0)
        b = tf.Variable(0.05)
        x = tf.compat.v1.placeholder(tf.float32)
        y = tf.add(tf.multiply(a, x), b)
        sess.run(tf.compat.v1.global_variables_initializer())

        # 存储为pb格式
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path)
        #输入输出必须是tensor,签名化       
        inputs = tf.compat.v1.saved_model.utils.build_tensor_info(x)
        outputs = tf.compat.v1.saved_model.utils.build_tensor_info(y)
        prediction_signature = (
            tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
                inputs={'input': inputs},
                outputs={'output': outputs},
                method_name=tf.compat.v1.saved_model.signature_constants.PREDICT_METHOD_NAME))

        builder.add_meta_graph_and_variables(
            sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict':
                    prediction_signature,         
        # 不知道为什么需要两次签名,但是少了下面这个会报错 
        #"error": "Serving signature name: \"serving_default\" not found in signature def"             
                    tf.compat.v1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                            prediction_signature,
            },
            main_op=tf.compat.v1.tables_initializer(),
            strip_default_attrs=True)

        builder.save()


if __name__ == "__main__":
    export()

拉tfservingdocker起服务

docker run -t --rm -p 8501:8501 \
   -v "$(pwd)/model/half_plus_ten:/models/half_plus_ten" \
   -e MODEL_NAME=half_plus_ten \
   tensorflow/serving

调用服务

curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_ten:predict

keras模型保存成pb更简单, 一行代码解决,注意要使用tensorflow.python.keras

model.save('model/keras_model/1', save_format='tf')

查看保存好的pb模型细节

saved_model_cli show --dir model/fm_item/1 --all

可以看到pb模型输入输出

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['item_id_hash_pos'] tensor_info:
        dtype: DT_INT32
        shape: (-1, 1)
        name: serving_default_item_id_hash_pos:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['lambda_1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 8)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

使用grpc调用服务:

from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import grpc
import tensorflow as tf

import numpy as np

def request_server(server_url):
    '''
    用于向TensorFlow Serving服务请求推理结果的函数。
    :param img_resized: 经过预处理的待推理图片数组,numpy array,shape:(h, w, 3)
    :param server_url: TensorFlow Serving的地址加端口,str,如:'0.0.0.0:8500' 
    :return: 模型返回的结果数组,numpy array
    '''
    # Request.
    channel = grpc.insecure_channel(server_url)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    request = predict_pb2.PredictRequest()
    request.model_spec.name = "half_plus_ten"  # 模型名称,启动容器命令的model_name参数
    request.model_spec.signature_name = "serving_default"  # 签名名称,刚才叫你记下来的
    # "input_1"是你导出模型时设置的输入名称,刚才叫你记下来的
    x_data = [[3428],[968],[3],[2]]
    request.inputs["item_id_hash_pos"].CopyFrom(tf.make_tensor_proto(x_data, dtype=tf.int32))
    response = stub.Predict(request, 5.0)  # 5 secs timeout
    print(response.outputs["lambda_1"])
    return np.asarray(response.outputs["lambda_1"].float_val) # fc2为输出名称,刚才叫你记下来的


if __name__ == "__main__":
    print(request_server("0.0.0.0:8500"))
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • TF-serving介绍 TensorFlow Serving是google提供的一种生产环境部署方案,一般来说在...
    612twilight阅读 1,805评论 2 1
  • 1 tf serving简介 2 保存模型为tf serving需要的pb格式文件 保存含有自定义签名信息的模型【...
    georgeguo阅读 2,689评论 0 1
  • TF 1.0到2.0迁移 在TensorFlow 2.0中,仍然可以运行未经修改的1.x代码(contrib除外)...
    AnuoF阅读 10,778评论 0 4
  • 推荐指数: 6.0 书籍主旨关键词:特权、焦点、注意力、语言联想、情景联想 观点: 1.统计学现在叫数据分析,社会...
    Jenaral阅读 5,742评论 0 5
  • 昨天,在回家的路上,坐在车里悠哉悠哉地看着三毛的《撒哈拉沙漠的故事》,我被里面的内容深深吸引住了,尽管上学时...
    夜阑晓语阅读 3,810评论 2 9