TensorFlow Serving使用指南

一、TensorFlow Serving简介

TensorFlow Serving是GOOGLE开源的一个服务系统,适用于部署机器学习模型,灵活、性能高、可用于生产环境。 TensorFlow Serving可以轻松部署新算法和实验,同时保持相同的服务器架构和API,它具有以下特性:

  • 支持模型版本控制和回滚
  • 支持并发,实现高吞吐量
  • 开箱即用,并且可定制化
  • 支持多模型服务
  • 支持批处理
  • 支持热更新
  • 支持分布式模型
  • 易于使用的inference api
  • 为gRPC expose port 8500,为REST API expose port 8501

二、安装与测试

安装TensorFlow Serving有三种方法:docker,二进制,源码编译,这里只介绍通过docker安装的步骤。

  1. 安装docker,通过docker拉取tensorflow serving镜像
# 此处拉取的是cpu serving镜像,如果要使用gpu,需要安装对应的docker和对应的gpu serving镜像
docker pull tensorflow/serving
  1. 拉取源码,部署源码中的half_plus_two模型测试serving是否可用
git clone https://github.com/tensorflow/serving
cd serving

TESTDATA="$(pwd)/serving/tensorflow_serving/servables/tensorflow/testdata"
# 设置端口转发,以下两条命令都可以启动服务
docker run -dt -p 8501:8501 -v "$TESTDATA/saved_model_half_plus_two_cpu:/models/half_plus_two" -e MODEL_NAME=half_plus_two tensorflow/serving
docker run -d -p 8501:8501 --mounttype=bind,source=$TESTDATA/saved_model_half_plus_two_cpu/,target=/models/half_plus_two -e MODEL_NAME=half_plus_two -t --name testserver tensorflow/serving

# 在服务器本机测试模型是否正常工作,这里需要注意,源码中half_plus_two的模型版本是00000123,但在访问时也必须输入v1而不是v000000123
curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:6005/v1/models/half_plus_two:predict
# 得到{"predictions": [2.5, 3.0, 4.5]}这个结果说明模型部署成功
  1. 命令说明
docker常用命令
# 启动/停止容器
docker start/stop $container_id或$container_name
# 查看运行容器
docker ps
# 查看全部容器
docker ps -a
# 删除指定容器
docker rm $container_id或$container_name
# 查看运行容器的日志
docker logs -f $container_id或$container_name
# docker 删除镜像
docker rmi image:tag # 例如docker rmi tensorflow/serving:latest
docker启动服务时参数
--mount:   表示要进行挂载
source:    指定要运行部署的模型地址, 也就是挂载的源,这个是在宿主机上的servable模型目录(pb格式模型而不是checkpoint模型)
target:     这个是要挂载的目标位置,也就是挂载到docker容器中的哪个位置,这是docker容器中的目录,模型默认挂在/models/目录下,如果改变路径会出现找不到model的错误
-t:         指定的是挂载到哪个容器
-d:         后台运行
-p:         指定主机到docker容器的端口映射
-e:         环境变量
-v:         docker数据卷
--name:     指定容器name,后续使用比用container_id更方便

三、模型格式转换

我们平时使用tf.Saver()保存的模型是checkpoint格式的,但是在TensorFlow Serving中一个servable的模型目录中是一个pb格式文件和一个名为variables的目录,因此需要在模型保存时就保存好可部署的模型格式,或者将已经训练好的checkpoint转换为servable format。

-ckpt_model
        -checkpoint
        -***.ckpt.data-00000-of-00001
        -***.ckpt.index
        -***.ckpt.meta
#转换为
-servable_model
        -version
                -saved_model.pb
                -variables

以下以命名实体识别空洞卷积模型为例,展示如何转换模型格式。

#coding:utf-8
import sys, os, io
import tensorflow as tf

def restore_and_save(input_checkpoint, export_path_base):
    checkpoint_file = tf.train.latest_checkpoint(input_checkpoint)
    graph = tf.Graph()

    with graph.as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        sess = tf.Session(config=session_conf)

        with sess.as_default():
            # 载入保存好的meta graph,恢复图中变量,通过SavedModelBuilder保存可部署的模型
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)
            print (graph.get_name_scope())

            export_path_base = export_path_base
            export_path = os.path.join(
                tf.compat.as_bytes(export_path_base),
                tf.compat.as_bytes(str(count)))
            print('Exporting trained model to', export_path)
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)

            # 建立签名映射,需要包括计算图中的placeholder(ChatInputs, SegInputs, Dropout)和我们需要的结果(project/logits,crf_loss/transitions)
            """
            build_tensor_info:建立一个基于提供的参数构造的TensorInfo protocol buffer,
            输入:tensorflow graph中的tensor;
            输出:基于提供的参数(tensor)构建的包含TensorInfo的protocol buffer
                        get_operation_by_name:通过name获取checkpoint中保存的变量,能够进行这一步的前提是在模型保存的时候给对应的变量赋予name
            """

            char_inputs =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("ChatInputs").outputs[0])
            seg_inputs =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("SegInputs").outputs[0])
            dropout =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("Dropout").outputs[0])
            logits =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("project/logits").outputs[0])

            transition_params =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("crf_loss/transitions").outputs[0])

            """
            signature_constants:SavedModel保存和恢复操作的签名常量。
            在序列标注的任务中,这里的method_name是"tensorflow/serving/predict"
            """
            # 定义模型的输入输出,建立调用接口与tensor签名之间的映射
            labeling_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        "charinputs":
                            char_inputs,
                        "dropout":
                            dropout,
                        "seginputs":
                            seg_inputs,
                    },
                    outputs={
                        "logits":
                            logits,
                        "transitions":
                            transition_params
                    },
                    method_name="tensorflow/serving/predict"))

            """
            tf.group : 创建一个将多个操作分组的操作,返回一个可以执行所有输入的操作
            """
            legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

            """
            add_meta_graph_and_variables:建立一个Saver来保存session中的变量,
                                          输出对应的原图的定义,这个函数假设保存的变量已经被初始化;
                                          对于一个SavedModelBuilder,这个API必须被调用一次来保存meta graph;
                                          对于后面添加的图结构,可以使用函数 add_meta_graph()来进行添加
            """
            # 建立模型名称与模型签名之间的映射
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                # 保存模型的方法名,与客户端的request.model_spec.signature_name对应
                signature_def_map={
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                       labeling_signature},
                legacy_init_op=legacy_init_op)

            builder.save()
            print("Build Done")

### 测试模型转换
tf.flags.DEFINE_string("ckpt_path",     "source_ckpt/IDCNN",             "path of source checkpoints")
tf.flags.DEFINE_string("pb_path",       "servable-models/IDCNN",             "path of servable models")
tf.flags.DEFINE_integer("version",      1,              "the number of model version")
tf.flags.DEFINE_string("classes",       'LOC',          "multi-models to be converted")
FLAGS = tf.flags.FLAGS

classes = FLAGS.classes
input_checkpoint = FLAGS.ckpt_path + "/" + classes
model_path = FLAGS.pb_path + '/' + classes

# 版本号控制
count = FLAGS.version
modify = False
if not os.path.exists(model_path):
    os.mkdir(model_path)
else:
    for v in os.listdir(model_path):
        print(type(v), v)
        if int(v) >= count:
            count = int(v)
            modify = True
    if modify:
        count += 1

# 模型格式转换
restore_and_save(input_checkpoint, model_path)

四、多模型部署

在现实情况里,一个任务可能需要用到多个模型,例如命名实体识别我训练了多个模型,对每个句子都需要汇总所有模型的结果,这时就需要用到多模型部署。在安装与测试一节中,我们介绍的是一个服务中只部署一个模型,这一节介绍下如何通过TF-Serving进行多模型部署。
多模型部署时,无法在命令行中指定MODEL_NAME了,需要编写一个如下的json配置文件,这里取名为model.config。

model_config_list: {
    config: {
        name: "model1",
        base_path: "/models/model1",
        model_platform: "tensorflow",
        model_version_policy: {
           all: {}
    },
    config: {
        name: "model2",
        base_path: "/models/model2",
        model_platform: "tensorflow",
        model_version_policy: {
           latest: {
               num_versions: 1
           }
        }
    },
    config: {
        name: "model3",
        base_path: "/models/model3",
        model_platform: "tensorflow",
        model_version_policy: {
           specific: {
               versions: 1
           }
        }
    }
}

model_version_policy可以删除,默认是部署最新版本的模型,如果想要部署指定版本或者全部版本,需要单独设定。

启动多模型服务命令时,需要将所有模型和配置文件逐个绑定,然后指定配置文件路径

sudo docker run -d -p 8500:8500 --mounttype=bind,source=/path/to/source_models/model1/,target=/models/model1 --mounttype=bind,source=/path/to/source_models/model2/,target=/models/model2 --mounttype=bind,source=/path/to/source_models/model3/,target=/models/model3 --mounttype=bind,source=/path/to/source_models/model.config,target=/models/model.config -t --name ner tensorflow/serving --model_config_file=/models/model.config

五、模型调用

在客户端调用模型可以安装tf-serving提供的包

pip install tensorflow-serving-api
# python3 需要安装tensorflow-serving-api-python3
# 注意tensorflow和tensorflow serving的版本最好一致(此代码版本:tensorflow==1.4.0, tensorflow-serving-api==1.4.0)
  1. 建立连接
channel = implementations.insecure_channel("127.0.0.1", 8500)
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
request = predict_pb2.PredictRequest()

# 指定启动tensorflow serving时配置的model_name和是保存模型时的方法名
request.model_spec.name = "model1"
request.model_spec.signature_name = "serving_default"
  1. 构造输入tensor
# 将文本进行预处理,得到list格式的输入数据,然后将其转为tensor后序列化
charinputs, seginputs = get_input(sentence)

request.inputs["charinputs"].ParseFromString(tf.contrib.util.make_tensor_proto(charinputs, dtype=tf.int32).SerializeToString())
request.inputs["seginputs"].ParseFromString(tf.contrib.util.make_tensor_proto(seginputs, dtype=tf.int32).SerializeToString())
request.inputs["dropout"].ParseFromString(tf.contrib.util.make_tensor_proto(1.0, dtype=tf.float32).SerializeToString())
  1. 获取模型输入结果
response = stub.Predict(request, timeout)

results = {}
for key in response.outputs:
    tensor_proto = response.outputs[key]
    results[key] =  tf.contrib.util.make_ndarray(tensor_proto)

# 从results中取所需要的结果,不一定是这两个变量哦
logits = results["logits"]
transitions = results["transitions"]

# 后处理
tags = get_output(logits, transitions)

TensorFlow模型的计算图,一般输入的类型都是张量,你需要提前把你的图像、文本或者其它数据先进行预处理,转换成张量才能输入到模型当中。而一般来说,这个数据预处理过程不会写进计算图里面,因此当你想使用TensorFlow Serving的时候,需要在客户端上写一大堆数据预处理代码,然后把张量通过gRPC发送到serving,最后接收结果。现实情况是你不可能要求每一个用户都要写一大堆预处理和后处理代码,用户只需使用简单POST一个请求,然后接收最终结果即可。因此,这些预处理和后处理代码必须由一个“中间人”来处理,这个“中间人”就是Web服务。关于此部分可以以后再详细介绍。

六、可能遇到的错误

docker: Error response from daemon: driver failed programming external connectivity on endpoint adoring_liskov (be1e57454fe716affc90cd6ba1d7fce7030f0de85f4262b5796aec3fbbafd7b9):  (iptables failed: iptables --wait -t filter -A DOCKER ! -i docker0 -o docker0 -p tcp -d 172.17.0.2 --dport 8501 -j ACCEPT: iptables: No chain/target/match by that name.
(exit status 1)).

docker daemon的bug,出现此情况需要重启docker。

找不到servable model所在路径时,需要检查source和target路径是否正确;target model设在/models/下一层,不要多层嵌套;检查版本号。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,816评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,729评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,300评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,780评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,890评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,084评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,151评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,912评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,355评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,666评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,809评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,504评论 4 334
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,150评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,882评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,121评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,628评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,724评论 2 351