在TensorFlow(Python, Java)环境下使用Keras模型

Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。有时候我们在使用keras设计好模型后,需要在其他平台进行运行,这时候我们就需要将keras h5 model转换为TensorFlow pb model,因为keras只是一个Python的高级库,而TensorFlow能够支持多平台的运行。

环境

  • Python 3.6
  • Keras 2.2.2
  • Tensorflow-gpu 1.8.0

Keras to Tensorflow

测试数据:

from keras.datasets import imdb

def get_data():
    max_features = 20000
    # cut texts after this number of words
    # (among top max_features most common words)
    maxlen = 100

    print('Loading data...')
    (x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
    print(x_train.shape, 'train sequences')
    print(x_test.shape, 'test sequences')

    print('Pad sequences (samples x time)')
    x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
    x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
    print('x_train shape:', x_train.shape)
    print('x_test shape:', x_test.shape)
    y_train = np.array(y_train)
    y_test = np.array(y_test)

    return x_train, x_test, y_train, y_test

生成一个keras模型进行训练,获得模型和对应的权重文件:

from keras.layers import Conv1D, GlobalMaxPooling1D, Embedding, Dense, Dropout
from keras.datasets import imdb
from keras.preprocessing import sequence
from keras.models import Sequential


def gen_keras_model(x_train, x_test, y_train, y_test, train=False):
    inp = Input(shape=(100,))
    x = Embedding(20000, 50)(inp)
    x = Dropout(0.2)(x)
    x = Conv1D(250, 3, padding='valid', activation='relu', strides=1)(x)
    x = GlobalMaxPooling1D()(x)
    x = Dense(250, activation='relu')(x)
    x = Dropout(0.2)(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inp, outputs=x)

    if train:
        model.compile(loss='binary_crossentropy',
                      optimizer='adam',
                      metrics=['accuracy'])

        model.fit(x_train, y_train,
                  batch_size=32,
                  epochs=2,
                  validation_data=(x_test, y_test))

        model.save_weights('model.h5')

    return model


if __name__ == '__main__':
    x_train, x_test, y_train, y_test = get_data()
    model = gen_keras_model(x_train, x_test, y_train, y_test, True)

下面的函数将keras model转换为Tensorflow pb文件:

  1. 首先构建一个Session与空的计算图,将这个计算图设置为默认的计算图。
  2. 获取keras model的输出节点,将这个输出节点与节点名在这个计算图中进行绑定。
  3. 使用convert_variables_to_constants函数保存数输出节点,函数会自动推导计算图并将计算图中的变量取值以常量的形式保存。在保存模型文件的时候,我们只是导出了GraphDef部分,GraphDef保存了从输入层到输出层的计算过程。
  4. 最后向指定目录写入pb文件。

如果你的graph使用了Keras的learning phase(在训练和测试中行为不同),你首先要做的事就是在graph中硬编码你的工作模式(设为0,即测试模式),该工作通过:1)使用Keras的后端注册一个learning phase常量,2)重新构建模型,来完成。

import tensorflow as tf
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io


def export_graph(model, export_path):
    input_names = model.input_names

    if not tf.gfile.Exists(export_path):
        tf.gfile.MakeDirs(export_path)

    with K.get_session() as sess:
        init_graph = sess.graph
        with init_graph.as_default():
            out_nodes = []

            for i in range(len(model.outputs)):
                out_nodes.append("output_" + str(i + 1))
                tf.identity(model.output[i], "output_" + str(i + 1))

            init_graph = sess.graph.as_graph_def()
            main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
            graph_io.write_graph(main_graph, export_path, name='model.pb', as_text=False)

    return input_names, out_nodes


if __name__ == '__main__':
    x_train, x_test, y_train, y_test = get_data()

    learning_phase = 0
    K.set_learning_phase(learning_phase)
    model = gen_keras_model(x_train, x_test, y_train, y_test, learning_phase)
    model.load_weights('model.h5')

    input_names, output_names = export_graph(model, 'model')

在Python Tensorflow环境下进行测试

  1. 首先在Session与Graph中读入pb文件,构建计算图。
  2. 然后根据输入张量与输出张量的张量名来获取到对应的张量,这里一定要加上:0。比如input_1:0是张量的名称而input_1表示的是节点的名称。
  3. 最后使用常规的Tensorflow操作来运行模型。
import numpy as np
import tensorflow as tf

from sklearn.metrics import accuracy_score


def run_graph(pb_file_path, input_name, output_name, x_test, y_test):
    tf.reset_default_graph()

    sess = tf.Session()
    with tf.gfile.FastGFile(pb_file_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')

    #输入
    input_x = sess.graph.get_tensor_by_name('{}:0'.format(input_name))
    #输出
    op = sess.graph.get_tensor_by_name('{}:0'.format(output_name))
    #预测结果
    pred = []
    for x in x_test:
        res = sess.run(op, {input_x: x.reshape(1, -1)})
        pred.append(res[0])
    
    pred = np.array([1 if p > 0.5 else 0 for p in pred])
    
    acc = accuracy_score(y_test, pred)

    print('Accuracy:{}'.format(acc))


if __name__ == '__main__':
    x_train, x_test, y_train, y_test = get_data()

    learning_phase = 0
    K.set_learning_phase(learning_phase)
    model = gen_keras_model(x_train, x_test, y_train, y_test, learning_phase)
    model.load_weights('model.h5')

    input_names, output_names = export_graph(model, 'model')

    pred = run_graph('model\model.pb', input_names[0], output_names[0], x_test, y_test)

输出如下:

Using TensorFlow backend.
Loading data...
(25000,) train sequences
(25000,) test sequences
Pad sequences (samples x time)
x_train shape: (25000, 100)
x_test shape: (25000, 100)
INFO:tensorflow:Froze 7 variables.
Converted 7 variables to const ops.
Accuracy:0.84388

在JavaTensorflow环境下进行测试

在 Windows 上安装按照以下步骤在 Windows 上安装适用于 Java 的 TensorFlow:

  1. 下载 libtensorflow.jar,这是 TensorFlow Java 归档 (JAR)。
  2. 下载 Windows 上适用于 Java 的 TensorFlow 对应的 Java 原生接口 (JNI) 文件。
  3. 解压缩该 .zip 文件。
  4. 配置到IDEA的External Libraries中。
setting

在Java中使用PB文件的代码如下,我们随机生成一个数组作为输入的张量进行测试。整个流程与Python下类似,需要注意的是生成输入张量时数组类型需要定义为float类型,不然会出现以下错误:

Exception in thread "main" java.lang.IllegalArgumentException: Expects arg[0] to be float but double is provided

Java下的测试代码:

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;

public class TFTest {
    public static void main(String[] args) throws IOException {
        String path = "E:\\Documents\\Desktop\\code\\glu\\model\\model.pb";
        float[][] input = new float[1][100];

        for (int i=0; i < 100; i++){
            input[0][i] = (float) (Math.random() * 100);
        }

        try (Graph graph = new Graph()){
            graph.importGraphDef(Files.readAllBytes(Paths.get(path)));

            try (Session sess = new Session(graph)){

                try (Tensor x = Tensor.create(input); 
                     Tensor y = sess.runner().feed("input_1", x).fetch("output_1").run().get(0)){

                    float[] res = (float[]) y.copyTo(new float[1]);
                    System.out.println(Arrays.toString(y.shape()));
                    System.out.println(Arrays.toString(res));
                }
            }
        }
    }
}

输出结果如下:

[1]
[0.088513985]

Process finished with exit code 0
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 218,036评论 6 506
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,046评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 164,411评论 0 354
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,622评论 1 293
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,661评论 6 392
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,521评论 1 304
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,288评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,200评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,644评论 1 314
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,837评论 3 336
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,953评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,673评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,281评论 3 329
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,889评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,011评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,119评论 3 370
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,901评论 2 355

推荐阅读更多精彩内容