TensorFlow Lite实战——在iOS上部署中文文本分类模型

前言

本文所使用的分类模型来自于CNN-RNN中文文本分类,基于TensorFlow,感谢开源。

最近一段时间需要用到中文文本分类这样一个功能,于是我马上想到了Create ML,但是经过自己的尝试以后发现Create ML并不支持中文的文本分类(不信可以自己试试)。

最近发现有道词典有离线翻译这样一个功能,我猜测这应该就是把模型下载到本地使用了,这么一看模型部署到移动端理论上是可行的。但各个深度学习框架我只了解过tensorflow,于是在有这样一个需求之下,我又回到了tensorflow这个大坑,去年年底说我这辈子都不会再用tensorflow了,没想到真香了。

实际上tensorflow所训练的模型是放在后端最合适,但由于我不想给APP维护一个健壮的后端,所以执着于把模型部署到移动端。这个是Demo

言归正传,从头部署一个模型我可以归纳出几个步骤

  1. 训练并测试模型,将模型保存为ckpt格式
  2. 将ckpt模型固化转成pb模型
  3. 通过TensorFlow Lite提供的方法将pb模型转换为tflite模型
  4. 使用cocoapods的方式引入TensorFlow Lite,并把模型导入工程
  5. 封装调用模型逻辑,进行文本分类

注意: 本篇博客仅根据上方的开源工程进行部署,其他的网络结构还需要具体问题具体分析。

大致分类原理

如果想要从头部署一遍,一定要对tensorflow有一定了解,因为不读懂工程的源码意思是基本上无法往下流程做的。

这个工程把每一个文本中的字符映射成一个个数字(id),通过一系列玄学操作,得到一个一维数组,其中前10个就是我们要关注的值,因为标签只有10个。

数据处理

我们需要了解数据处理的方式即输入和输出,这样我们才能编写代码在iOS APP中进行预测。

输入

这个开源工程中会把每一个字符(汉字)映射成一个id,这个id来自于数据集中的行,意思就是第一行对应的字符id就是0,第二行对应的是1,以此类推。这样我们就获得了一个id的数组。并且这个id数组需要处理成一个固定长度,本文在iOS中处理方式为不足则数组后面添0,多余则移除数组末尾。

输出

输出的是一个数组,数量会超过10个,但因为数据集中的分类只有10个,所以我们只需要关注这个数组的前10个即可。这前10个数组对应的下标就是标签数组中的下标,数组的值就是预测的概率。所以输出的数组0-10的下标就对应了标签数组中0-10具体分类的可能性。

部署

训练模型

本文使用开源工程中的CNN网络,因为TensorFlow Lite支持的operators有限,所以不是所有的TensorFlow中的operators都支持,如果出现不支持的情况就会在转换中出现类似如下的错误:

Some of the operators in the model are not supported by the standard TensorFlow Lite runtime. If you have a custom implementation for them you can disable this error with --allow_custom_ops, or by setting allow_custom_ops=True when calling tf.contrib.lite.toco_convert(). Here is a list of operators for which  you will need custom implementations: RandomUniform

这里的错误中可以发现不支持的operator是RandomUniform。查找之后发现CNN中的tf.contrib.layers.dropout不受支持,但是这个问题不大,我们可以用L2正则化去替代它防止过拟合。

下面是修改后的参考代码:

# coding: utf-8
from functools import partial

import tensorflow as tf


class TCNNConfig(object):
    """CNN配置参数"""

    embedding_dim = 64  # 词向量维度
    seq_length = 600  # 序列长度
    num_classes = 10  # 类别数
    num_filters = 256  # 卷积核数目
    kernel_size = 5  # 卷积核尺寸
    vocab_size = 5000  # 词汇表达小

    hidden_dim = 128  # 全连接层神经元

    dropout_keep_prob = 0.5  # dropout保留比例
    learning_rate = 1e-3  # 学习率

    batch_size = 64  # 每批训练大小
    num_epochs = 10  # 总迭代轮次

    print_per_batch = 100  # 每多少轮输出一次结果
    save_per_batch = 10  # 每多少轮存入tensorboard

    scale = 0.01


class TextCNN(object):
    """文本分类,CNN模型"""

    def __init__(self, config):
        self.config = config

        # 三个待输入的数据
        self.input_x = tf.placeholder(tf.int32, [None, self.config.seq_length], name='input_x')
        self.input_y = tf.placeholder(tf.float32, [None, self.config.num_classes], name='input_y')
        self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        self.cnn()

    def cnn(self):
        """CNN模型"""
        my_dense_layer = partial(
            tf.layers.dense, activation=tf.nn.relu,
            # 在这里传入了L2正则化函数,并在函数中传入正则化系数。
            kernel_regularizer=tf.contrib.layers.l2_regularizer(self.config.scale)
        )
        # 词向量映射
        with tf.device('/cpu:0'):
            embedding = tf.get_variable('embedding', [self.config.vocab_size, self.config.embedding_dim])
            embedding_inputs = tf.nn.embedding_lookup(embedding, self.input_x)

        with tf.name_scope("cnn"):
            # CNN layer
            conv = tf.layers.conv1d(embedding_inputs, self.config.num_filters, self.config.kernel_size, name='conv')
            # global max pooling layer
            gmp = tf.reduce_max(conv, reduction_indices=[1], name='gmp')

        with tf.name_scope("score"):
            # 全连接层
            fc = my_dense_layer(gmp, self.config.hidden_dim, name='fc1')
            # fc = tf.layers.dense(gmp, self.config.hidden_dim, name='fc1')
            # fc = tf.contrib.layers.dropout(fc, self.keep_prob)
            # fc = tf.nn.relu(fc)

            # 分类器
            self.logits = tf.layers.dense(fc, self.config.num_classes, name='fc2')
            self.y_pred_cls = tf.argmax(tf.nn.softmax(self.logits), 1)  # 预测类别

        with tf.name_scope("optimize"):
            # 损失函数,交叉熵
            cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.input_y)
            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            self.loss = tf.add_n([tf.reduce_mean(cross_entropy)] + reg_losses)
            # self.loss = tf.reduce_mean(cross_entropy)
            # 优化器
            self.optim = tf.train.AdamOptimizer(learning_rate=self.config.learning_rate).minimize(self.loss)

        with tf.name_scope("accuracy"):
            # 准确率
            correct_pred = tf.equal(tf.argmax(self.input_y, 1), self.y_pred_cls)
            self.acc = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

接下来在run_cnn.py中经过训练就能获得如下ckpt模型了:


ckpt模型

将ckpt模型固化转成pb模型

在固化模型这一个环节,你需要通读这个开源工程才行,不然你肯定不了解它的网络结构以及它的输入和输出。这也是对iOS开发者非常不友好的地方。

通过源码我们可以得知TextCNN这个类中的self.logits这个属性就是我们需要关注的输出,所以我们可以通过下面这段代码打印出tensor,然后找到我们需要的输出的name

ops = sess.graph.get_operations()
        for op in ops:
            print(op)

这里我们需要的name是

output_node_names = "score/fc2/BiasAdd"

参考源码:

def freeze_graph(input_checkpoint):
    """
    :param input_checkpoint:
    :return:
    """
    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径

    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
    output_node_names = "score/fc2/BiasAdd"
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)  # 恢复图并得到数据
        output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
            sess=sess,
            input_graph_def=sess.graph_def,  # 等于:sess.graph_def
            output_node_names=output_node_names.split(",")
        )  # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
            f.write(output_graph_def.SerializeToString())  # 序列化输出

input_checkpoint为你的ckpt模型路径

将pb模型转换为tflite模型

下面是from_frozen_graph方法的注解。这里我就要吐槽一下了,TensorFlow Lite的文档未免太敷衍了,说好的传入参数是一个[tensor],结果老报错,在打断点调试了它们库的源码情况下发现竟然要求的是传入tensor的name???

from_frozen_graph方法注解

这个只要没有出现operator不支持的情况就很简单,直接上源码就完了

def convert_to_tflite():
    input_tensors = [
        "input_x"
    ]
    output_tensors = [
        "score/fc2/BiasAdd"
    ]
    converter = tf.lite.TFLiteConverter.from_frozen_graph(
        output_graph,
        input_tensors,
        output_tensors)
    converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                            tf.lite.OpsSet.SELECT_TF_OPS]
    tflite_model = converter.convert()
    open(output_tflite_model, "wb").write(tflite_model)

其中input_x是输入的name

使用cocoapods的方式引入TensorFlow Lite

TensorFlow Lite有好几个库,原生的需要写C++,在一顿操作之下我放弃了,完全看不懂tensor的输入嘛。还有OC封装的以及swift封装的。因为我的工程是swift写的,所以我直接用swift的TensorFlow Lite库

按照他们的README

pod 'TensorFlowLiteSwift'
import TensorFlowLite

就引入了,这一点就很友好了,比什么直接编译TensorFlow到iOS工程里那是简单的不能再简单了。

封装调用模型逻辑,进行文本分类

在喂数据进行预测时我们也要按照开源工程里喂数据的方式进行一番操作。调用的逻辑我们可以参考官方Example

导入模型

我们需要导入模型、分类和字符id,这在本文的前言中提供的demo中有体现。

必须导入的东西

初始化Interpreter

private init() {
        let options = InterpreterOptions()
        do {
            // Create the `Interpreter`.
            let modelPath = Bundle.init(for: TextClassifier.self).path(forResource: "model", ofType: "tflite")!
            interpreter = try Interpreter(modelPath: modelPath, options: options)
            // Allocate memory for the model's input `Tensor`s.
            try interpreter.allocateTensors()
        } catch {
            print("Failed to create the interpreter with error: \(error.localizedDescription)")
        }
    }

加载标签、id以及将字符转换为id

private func loadLabels() {
        if let path = Bundle.init(for: TextClassifier.self).path(forResource: "labels", ofType: "txt") {
            let fileManager = FileManager.default
            let txtData = fileManager.contents(atPath: path)!
            let content = String.init(data: txtData, encoding: .utf8)
            let rowArray = content?.split(separator: "\n") ?? []
            for row in rowArray {
                labels.append(String(row))
            }
        }
    }
    
    private func loadTextId() {
        if let path = Bundle.init(for: TextClassifier.self).path(forResource: "text_id", ofType: "txt") {
            let fileManager = FileManager.default
            let txtData = fileManager.contents(atPath: path)!
            let content = String.init(data: txtData, encoding: .utf8)
            let rowArray = content?.split(separator: "\n") ?? []
            var i = 0
            for row in rowArray {
                textIdInfo[String(row)] = i
                i += 1
            }
        }
    }
    
    private func transformTextToId(_ text: String) -> [Int] {
        var idArray: [Int] = []
        for str in text {
            idArray.append(textIdInfo[String(str)]!)
        }
        //根据python工程中的输入设置,超出截取前面,不足后面补0
        while idArray.count < 2400 {
            idArray.append(0)
        }
        while idArray.count > 2400 {
            idArray.removeLast()
        }
        return idArray
    }

进行预测

public func runModel(with text: String, closure: @escaping(InferenceReslutClosure)) {
        DispatchQueue.global().async {
            let idArray = self.transformTextToId(text)
            let outputTensor: Tensor
            do {
                _ = try self.interpreter.input(at: 0)
                let idData = Data.init(bytes: idArray, count: idArray.count)
                try self.interpreter.copy(idData, toInputAt: 0)
                try self.interpreter.invoke()
                outputTensor = try self.interpreter.output(at: 0)
            } catch {
                print("An error occurred while entering data: \(error.localizedDescription)")
                return
            }
            let results: [Float]
            switch outputTensor.dataType {
            case .uInt8:
                guard let quantization = outputTensor.quantizationParameters else {
                    print("No results returned because the quantization values for the output tensor are nil.")
                    return
                }
                let quantizedResults = [UInt8](outputTensor.data)
                results = quantizedResults.map {
                    quantization.scale * Float(Int($0) - quantization.zeroPoint)
                }
            case .float32:
                results = outputTensor.data.withUnsafeBytes( { (ptr: UnsafeRawBufferPointer) in
                    [Float32](UnsafeBufferPointer.init(start: ptr.baseAddress?.assumingMemoryBound(to: Float32.self), count: ptr.count))
                })
            default:
                print("Output tensor data type \(outputTensor.dataType) is unsupported for this app.")
                return
            }
            let resultArray = self.getTopN(results: results)
            DispatchQueue.main.async {
                closure(resultArray)
            }
        }
    }

首先我们需要把[Int]类型转换为Data类型提供给interpreter,可以如下方法转换

let idData = Data.init(bytes: idArray, count: idArray.count)

invoke()方法为调用模型进行预测

我们拿到输出outputTensor以后,它的dataType中的float32类型就是我们需要的输出,这是因为在开源工程中的输出就是float32类型。这里我们需要用swift的指针去把Data类型换为[Float]类型,如下:

results = outputTensor.data.withUnsafeBytes( { (ptr: UnsafeRawBufferPointer) in
                    [Float32](UnsafeBufferPointer.init(start: ptr.baseAddress?.assumingMemoryBound(to: Float32.self), count: ptr.count))
                })

至于上面那个.UInt8我没有搞懂是什么意思,但我想我的输出都是float32类型,所以应该是不会走上面那个case。

最后我们通过getTopN方法取到前3个可能性最大的标签(预测值)

private func getTopN(results: [Float]) -> [Inference] {
        //创建元组数组 [(labelIndex: Int, confidence: Float)]
        let zippedResults = zip(labels.indices, results)
        //从大到小排序并选出前resultCount个(根据python工程中的训练,只取前10个,因为分类只有10个)
        let sortedResults = zippedResults.sorted { $0.1 > $1.1 }.prefix(resultCount)
        //返回前resultCount对应的标签以及预测值
        return sortedResults.map { result in Inference.init(confidence: result.1, label: labels[result.0]) }
    }

这里取的逻辑就像上述所说的,我们只关注输出一维数组的前10个元素,然后给他们排个序取最大三个值,这三个值所在的下标直接在标签数组中取值就能获得对应的预测分类

最后

博客只是一个预览,详细的逻辑还是需要直接看Demo

参考

CNN-RNN中文文本分类,基于TensorFlow
TensorFlow for Poets 2: TFLite iOS
【IOS/Android】TensorflowLite移动端部署
TensorFlow Lite Swift Example
Tensorflow Convert pb file to TFLITE using python

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,163评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,301评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,089评论 0 352
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,093评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,110评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,079评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,005评论 3 417
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,840评论 0 273
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,278评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,497评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,667评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,394评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,980评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,628评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,796评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,649评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,548评论 2 352

推荐阅读更多精彩内容