Loading a TensorFlow 2 SavedModel in Java

Reference:
Calling a trained TensorFlow 2.0 model from Spark (Scala)

1. Train and save the model with TF2:

import tensorflow as tf
from tensorflow.keras import models,layers,optimizers

## Number of samples
n = 800

## Generate a synthetic dataset
X = tf.random.uniform([n,2],minval=-10,maxval=10)
w0 = tf.constant([[2.0],[-1.0]])
b0 = tf.constant(3.0)

Y = X@w0 + b0 + tf.random.normal([n,1],mean = 0.0,stddev= 2.0)  # @ is matrix multiplication; add Gaussian noise

## Build the model
tf.keras.backend.clear_session()
inputs = layers.Input(shape = (2,),name ="inputs") # name the input "inputs" (becomes serving_default_inputs:0 in the signature)
outputs = layers.Dense(1, name = "outputs")(inputs) # name the output "outputs" (the signature's output key)
linear = models.Model(inputs = inputs,outputs = outputs)
linear.summary()

## Train with the fit method
linear.compile(optimizer="rmsprop",loss="mse",metrics=["mae"])
linear.fit(X,Y,batch_size = 8,epochs = 100)

tf.print("w = ",linear.layers[1].kernel)
tf.print("b = ",linear.layers[1].bias)

## Save the model in SavedModel format (saved_model.pb plus variables/)
export_path = "/your_path/tf2_linear"
linear.save(export_path, save_format="tf")
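Because the synthetic data follow y = 2*x1 - x2 + 3 plus Gaussian noise with stddev 2, a well-trained model should predict values near that noise-free line. The Java sketch below computes those reference values for the same five rows that the Java prediction step feeds to the model, so the model output can be sanity-checked. The class name `GroundTruth` is hypothetical and not part of the original post; the sketch assumes the learned `w` and `b` printed by `tf.print` end up close to `w0` and `b0`, so real predictions will differ somewhat.

```java
import java.util.Arrays;

/** Hypothetical helper: noise-free targets y = 2*x1 - x2 + 3 from the training script. */
final class GroundTruth {
    // Same constants as w0 and b0 in the Python script (assumed the model learns them approximately)
    static final float W1 = 2.0f;
    static final float W2 = -1.0f;
    static final float B = 3.0f;

    /** Returns one expected value per row; each row must have exactly 2 features. */
    static float[] expected(float[][] rows) {
        float[] y = new float[rows.length];
        for (int i = 0; i < rows.length; i++) {
            if (rows[i].length != 2) {
                throw new IllegalArgumentException("row " + i + " must have 2 features");
            }
            y[i] = W1 * rows[i][0] + W2 * rows[i][1] + B;
        }
        return y;
    }

    public static void main(String[] args) {
        // The same five rows that PredictNN feeds to the model
        float[][] input = {
            {2.6327686f, -9.201903f},
            {-1.3209248f, 8.569574f},
            {-5.6642127f, 3.3681698f},
            {9.604832f, 5.9664965f},
            {-0.8812313f, -6.76733f}
        };
        System.out.println(Arrays.toString(expected(input)));
    }
}
```

For these rows the reference values are roughly 17.47, -8.21, -11.70, 16.24 and 8.00; predictions from the saved model should be in the same neighborhood, not identical.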

Resulting SavedModel directory:

 ~/demo/your_path  tree
.
└── tf2_linear
    ├── assets
    ├── saved_model.pb
    └── variables
        ├── variables.data-00000-of-00001
        └── variables.index

3 directories, 3 files
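`saved_model.pb` holds the graph and its signatures, and `variables/` holds the trained weights; `SavedModelBundle.load` reads this directory, so a wrong export path makes loading fail. The hypothetical class below (not from the original post) is a pre-flight check using only `java.nio.file` that looks for exactly the files in this listing. It is stricter than the format itself, which also allows `saved_model.pbtxt`, and it does not validate file contents.

```java
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

/** Hypothetical pre-flight check to run before SavedModelBundle.load. */
final class SavedModelDirCheck {
    /** True if dir contains saved_model.pb and variables/variables.index, as in the tree listing. */
    static boolean looksLikeSavedModel(Path dir) {
        Path variables = dir.resolve("variables");
        return Files.isRegularFile(dir.resolve("saved_model.pb"))
                && Files.isDirectory(variables)
                && Files.isRegularFile(variables.resolve("variables.index"));
    }

    public static void main(String[] args) {
        // Defaults to the placeholder export path used in this post
        Path dir = Paths.get(args.length > 0 ? args[0] : "/your_path/tf2_linear");
        System.out.println(dir + " looks like a SavedModel: " + looksLikeSavedModel(dir));
    }
}
```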

2. Load the model in Java and run prediction

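Inspect the model details (the tag-set and the input/output tensor names are needed to load the model and run prediction in Java):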

 ~/demo/your_path  saved_model_cli show --dir ./tf2_linear --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: serving_default_inputs:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict
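Two parts of this output drive the Java code: the tag-set `serve` (passed to `SavedModelBundle.load`) and the tensor names under `serving_default` (`serving_default_inputs:0` to feed, `StatefulPartitionedCall:0` to fetch). A name of the form `op:index` refers to output `index` of operation `op`, and the shape `(-1, 2)` means any number of rows, each with exactly 2 floats. The sketch below (hypothetical class `SignatureInfo`, standard library only) hard-codes the values from this CLI output rather than reading the SavedModel; it splits a tensor name into operation and index and checks a batch against a shape where `-1` accepts any size.

```java
/** Hypothetical helper mirroring the serving_default signature printed by saved_model_cli. */
final class SignatureInfo {
    static final String TAG = "serve";
    static final String INPUT_NAME = "serving_default_inputs:0";
    static final String OUTPUT_NAME = "StatefulPartitionedCall:0";
    static final long[] INPUT_SHAPE = {-1, 2};

    /** Operation part of "op:index", e.g. "serving_default_inputs". */
    static String opName(String tensorName) {
        int colon = tensorName.lastIndexOf(':');
        return colon < 0 ? tensorName : tensorName.substring(0, colon);
    }

    /** Output index part of "op:index"; a bare operation name means output 0. */
    static int outputIndex(String tensorName) {
        int colon = tensorName.lastIndexOf(':');
        return colon < 0 ? 0 : Integer.parseInt(tensorName.substring(colon + 1));
    }

    /** True if a 2-D batch fits a rank-2 shape in which -1 accepts any size. */
    static boolean matches(float[][] batch, long[] shape) {
        if (shape.length != 2) {
            return false;
        }
        if (shape[0] != -1 && shape[0] != batch.length) {
            return false;
        }
        for (float[] row : batch) {
            if (shape[1] != -1 && shape[1] != row.length) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(opName(INPUT_NAME) + " / output " + outputIndex(INPUT_NAME));
        float[][] ok = {{1f, 2f}, {3f, 4f}};
        float[][] bad = {{1f, 2f, 3f}};
        System.out.println(matches(ok, INPUT_SHAPE) + " " + matches(bad, INPUT_SHAPE));
    }
}
```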

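Maven dependencies (the TensorFlow 1.x Java API, version 1.15.0, plus fastjson, which is used here only to print arrays as JSON):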

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.15.0</version>
</dependency>
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>fastjson</artifactId>
    <version>1.2.73</version>
</dependency>

Java code:

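package com.ml.demo.tf;

import com.alibaba.fastjson.JSON;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class PredictNN {
    public static void main(String[] args) {
        // Load the SavedModel with the "serve" tag-set reported by saved_model_cli;
        // closing the bundle also releases its Graph and Session
        try (SavedModelBundle bundle = SavedModelBundle.load("/your_path/tf2_linear", "serve")) {
            Session session = bundle.session();

            // Each row is one sample with 2 features, matching input shape (-1, 2)
            float[][] input = {
                {2.6327686f, -9.201903f},
                {-1.3209248f, 8.569574f},
                {-5.6642127f, 3.3681698f},
                {9.604832f, 5.9664965f},
                {-0.8812313f, -6.76733f}
            };
            System.out.println("input: \n" + JSON.toJSONString(input));

            // Feed and fetch by the tensor names from the serving_default signature;
            // tensors hold native memory, so close them when done
            try (Tensor<Float> inputTensor = Tensor.create(input, Float.class);
                 Tensor<?> resultTensor = session.runner()
                         .feed("serving_default_inputs:0", inputTensor)
                         .fetch("StatefulPartitionedCall:0")
                         .run().get(0)) {
                // Output shape is (-1, 1): one prediction per input row
                float[][] result = new float[input.length][1];
                resultTensor.copyTo(result);
                System.out.println("result: \n" + JSON.toJSONString(result));
            }
        }
    }
}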
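Because the output signature has shape `(-1, 1)`, `copyTo` fills a `float[n][1]` array. Callers that want one prediction per input row as a flat array can flatten that column. The helper below is a hypothetical sketch, not part of the original post; it works on the plain Java array after `copyTo`, so it does not depend on the TensorFlow API, and it checks that the row count matches the input batch.

```java
import java.util.Arrays;

/** Hypothetical post-processing for the (-1, 1) output copied out by resultTensor.copyTo. */
final class PredictionUtil {
    /** Flattens an n x 1 result into n predictions after checking it matches the input batch size. */
    static float[] flattenColumn(float[][] result, int expectedRows) {
        if (result.length != expectedRows) {
            throw new IllegalArgumentException(
                    "expected " + expectedRows + " rows, got " + result.length);
        }
        float[] flat = new float[result.length];
        for (int i = 0; i < result.length; i++) {
            if (result[i].length != 1) {
                throw new IllegalArgumentException("row " + i + " must have exactly 1 column");
            }
            flat[i] = result[i][0];
        }
        return flat;
    }

    public static void main(String[] args) {
        // Placeholder values standing in for a real model output
        float[][] result = {{1.5f}, {-2.0f}, {0.25f}};
        System.out.println(Arrays.toString(flattenColumn(result, 3)));
    }
}
```

In `PredictNN`, this would be called as `flattenColumn(result, input.length)` right after `copyTo`.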
Output log: