tensorflow将ckpt模型转为pb模型

获取原网络中的所有节点

在训练代码中定义好图之后加入以下代码:

for node in tf.get_default_graph().as_graph_def().node:

    print(node.name)

主要是要查看最后一个节点的名字

模型转化

不再重新建图时, 使用tf.train.import_meta_graph

def freeze_graph(input_checkpoint,output_graph):

    '''

    :param input_checkpoint:ckpt模型路径

    :param output_graph: pb模型保存路径

    '''

    output_node_names = " " # 填入第一步得到的最后一个节点名

    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)

    with tf.Session() as sess:

        saver.restore(sess, input_checkpoint) #恢复图并得到数据

        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定

            sess=sess,

            input_graph_def=sess.graph_def,# 等于:sess.graph_def

            output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开

        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型

            f.write(output_graph_def.SerializeToString()) #序列化输出

        print("%d ops in the final graph." % len(output_graph_def.node)) # 统计图中总的操作节点数

或者修改前传代码,使用tf.train.Saver()

在前传代码里,restore模型

restorer = tf.train.Saver(tf.global_variables())

ckpt = tf.train.get_checkpoint_state(' ') # 填入ckpt模型所在文件夹路径

model_path = ckpt.model_checkpoint_path # 读取checkpoint文件里的第一行

with tf.Session() as sess:

    # Create a saver.

    sess.run(tf.local_variables_initializer())

    sess.run(tf.global_variables_initializer())

    try:

        restorer.restore(sess, model_path)

        print(model_path.split('/')[-1] + " restored!")

    except IOError:

        print("checkpoints not found.")

    output_graph_def = tf.graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定

        sess=sess,

        input_graph_def=sess.graph_def,  # 等于:sess.graph_def

        output_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开

    with tf.gfile.GFile(out_pb_path, "wb") as f:  # 保存模型

        f.write(output_graph_def.SerializeToString())  # 序列化输出

    print("%d ops in the final graph." % len(output_graph_def.node))

  # 统计图中总的操作节点数

从pb模型中读取节点

#coding:utf-8

import tensorflow as tf

from tensorflow.python.framework import graph_util

tf.reset_default_graph()  # 重置计算图

output_graph_path = 'model/model_tfnew.pb'

with tf.Session() as sess:

    tf.global_variables_initializer().run()

    output_graph_def = tf.GraphDef()

    # 获得默认的图

    graph = tf.get_default_graph()

    with open(output_graph_path, "rb") as f:

        output_graph_def.ParseFromString(f.read())

        _ = tf.import_graph_def(output_graph_def, name="")

        # 得到当前图有几个操作节点

        print("%d ops in the final graph." % len(output_graph_def.node))

        tensor_name = [tensor.name for tensor in output_graph_def.node]

        print(tensor_name)

        print('---------------------------')

        # 在log_graph文件夹下生产日志文件,可以在tensorboard中可视化模型

        summaryWriter = tf.summary.FileWriter('log_graph/', graph)

        for op in graph.get_operations():

            # print出tensor的name和值

            print(op.name, op.values())


参考:https://blog.csdn.net/u010397980/article/details/84889174

           https://blog.csdn.net/guyuealian/article/details/82218092

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

  • 在这篇tensorflow教程中,我会解释: 1) Tensorflow的模型(model)长什么样子? 2) 如...
    JunsorPeng阅读 3,596评论 1 6
  • 这篇文章是针对有tensorflow基础但是记不住复杂变量函数的读者,文章列举了从输入变量到前向传播,反向优化,数...
    horsetif阅读 1,237评论 0 1
  • 杨慧霞 洛阳 焦点讲师班二期 坚持分享第1106天 “稳一些,凡事想好了再做。“多么熟悉的一句话呀!小时...
    yhx慧心慧语阅读 420评论 3 1
  • 女儿喜欢一道菜:将小咸鱼或小虾米塞入茄块,调味,下锅煎炒。这道菜又咸又香,加上熟茄的软,那叫一个好吃。 冰箱里有茄...
    你是谁谁是我阅读 281评论 0 1
  • 1.事件:和朋友聊天。 2.感受:亲切、高兴、着急、平静。 3.想法:见到好朋友感到很亲切、高兴。听到朋友说对孩子...
    王穆宁阅读 206评论 0 0

友情链接更多精彩内容