Java调用Keras、Tensorflow模型

实现python离线训练模型,Java在线预测部署。查看原文

目前深度学习主流使用python训练自己的模型,有非常多的框架提供了能快速搭建神经网络的功能,其中Keras提供了high-level的语法,底层可以使用tensorflow或者theano。

但是有很多公司后台应用是用Java开发的,如果用python提供HTTP接口,对业务延迟要求比较高的话,仍然会有一定得延迟,所以能不能使用Java调用模型,python可以离线的训练模型?(tensorflow也提供了成熟的部署方案TensorFlow Serving

手头上有一个用Keras训练的模型,网上关于Java调用Keras模型的资料不是很多,而且大部分是重复的,并且也没有讲的很详细。大致有两种方案,一种是基于Java的深度学习库导入Keras模型实现,另外一种是用tensorflow提供的Java接口调用。

Deeplearning4J

Eclipse Deeplearning4j is the first commercial-grade, open-source, distributed deep-learning library written for Java and Scala. Integrated with Hadoop and Spark, DL4J brings AIAI to business environments for use on distributed GPUs and CPUs.

Deeplearning4j目前支持导入Keras训练的模型,并且提供了类似python中numpy的一些功能,更方便地处理结构化的数据。遗憾的是,Deeplearning4j现在只覆盖了Keras <2.0版本的大部分Layer,如果你是用Keras 2.0以上的版本,在导入模型的时候可能会报错。

了解更多:
Keras Model Import: Supported Features
Importing Models From Keras to Deeplearning4j

Tensorflow

文档,Java的文档很少,不过调用模型的过程也很简单。采用这种方式调用模型需要先将Keras导出的模型转成tensorflow的protobuf协议的模型。

1、Keras的h5模型转为pb模型

在Keras中使用model.save(model.h5)保存当前模型为HDF5格式的文件中。
Keras的后端框架使用的是tensorflow,所以先把模型导出为pb模型。在Java中只需要调用模型进行预测,所以将当前的graph中的Variable全部变成Constant,并且使用训练后的weight。以下是freeze graph的代码:

    def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
        """
        :param session: 需要转换的tensorflow的session
        :param keep_var_names:需要保留的variable,默认全部转换constant
        :param output_names:output的名字
        :param clear_devices:是否移除设备指令以获得更好的可移植性
        :return:
        """
        from tensorflow.python.framework.graph_util import convert_variables_to_constants
        graph = session.graph
        with graph.as_default():
            freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
            output_names = output_names or []
            # 如果指定了output名字,则复制一个新的Tensor,并且以指定的名字命名
            if len(output_names) > 0:
                for i in range(output_names):
                    # 当前graph中复制一个新的Tensor,指定名字
                    tf.identity(model.model.outputs[i], name=output_names[i])
            output_names += [v.op.name for v in tf.global_variables()]
            input_graph_def = graph.as_graph_def()
            if clear_devices:
                for node in input_graph_def.node:
                    node.device = ""
            frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                          output_names, freeze_var_names)
            return frozen_graph

该方法可以将tensor为Variable的graph全部转为constant并且使用训练后的weight。注意output_name比较重要,后面Java调用模型的时候会用到。

在Keras中,模型是这么定义的:

    def create_model(self):
        input_tensor = Input(shape=(self.maxlen,), name="input")
        x = Embedding(len(self.text2id) + 1, 200)(input_tensor)
        x = Bidirectional(LSTM(128))(x)
        x = Dense(256, activation="relu")(x)
        x = Dropout(self.dropout)(x)
        x = Dense(len(self.id2class), activation='softmax', name="output_softmax")(x)
        model = Model(inputs=input_tensor, outputs=x)
        model.compile(loss='categorical_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])

下面的代码可以查看定义好的Keras模型的输入、输出的name,这对之后Java调用有帮助。

print(model.input.op.name)
print(model.output.op.name)

训练好Keras模型后,转换为pb模型:

from keras import backend as K
import tensorflow as tf

model.load_model("model.h5")
print(model.input.op.name)
print(model.output.op.name)
# 自定义output_names
frozen_graph = freeze_session(K.get_session(), output_names=["output"])
tf.train.write_graph(frozen_graph, "./", "model.pb", as_text=False)

### 输出:
# input
# output_softmax/Softmax
# 如果不自定义output_name,则生成的pb模型的output_name为output_softmax/Softmax,如果自定义则以自定义名为output_name

运行之后会生成model.pb的模型,这将是之后调用的模型。

2、Java调用

新建一个maven项目,pom里面导入tensorflow包:

<dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow</artifactId>
            <version>1.6.0</version>
</dependency>

核心代码:

public void predict() throws Exception {
        try (Graph graph = new Graph()) {
            graph.importGraphDef(Files.readAllBytes(Paths.get(
                    "path/to/model.pb"
            )));
            try (Session sess = new Session(graph)) {
                // 自己构造一个输入
                float[][] input = {{56, 632, 675, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
                try (Tensor x = Tensor.create(input);
                    // input是输入的name,output是输出的name
                    Tensor y = sess.runner().feed("input", x).fetch("output").run().get(0)) {
                    float[][] result = new float[1][y.shape[1]];
                    y.copyTo(result);
                    System.out.println(Arrays.toString(y.shape()));
                    System.out.println(Arrays.toString(result[0]));
                }
            }
        }
    }

Graph和Tensor对象都是需要通过close()方法显式地释放占用的资源,代码中使用了try-with-resources的方法实现的。

至此,已经可以实现Keras离线训练,Java在线预测的功能。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,047评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,807评论 3 386
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,501评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,839评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,951评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,117评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,188评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,929评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,372评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,679评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,837评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,536评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,168评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,886评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,129评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,665评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,739评论 2 351

推荐阅读更多精彩内容