将ckpt转成pb文件

  • Estimator(导出为 SavedModel:saved_model.pb + variables/ 目录)

    # Export a SavedModel from an Estimator checkpoint.
    # No tf.train.Scaffold is involved: export_saved_model restores the weights
    # from `checkpoint_path` itself (or the latest checkpoint in model_dir when
    # it is None). `model_fn`, `params`, `model_dir`, `config` and
    # `checkpoint_path` are assumed to be defined earlier in the script.
    model = tf.estimator.Estimator(
        model_fn=model_fn, params=params, model_dir=model_dir, config=config)

    def serving_input_fn():
        """Build the serving input receiver used at export time.

        Creates one int32 placeholder named 'x' with shape [None, 4] and
        exposes it both as the receiver tensor and as the feature 'x'.

        Returns:
            tf.estimator.export.ServingInputReceiver for the export graph.
        """
        # Construct the receiver directly. build_raw_serving_input_receiver_fn
        # creates its *own* placeholders (copying dtype/shape/name from the
        # tensors it is given), so calling it here after creating 'x' would
        # leave an extra, unused placeholder in the export graph and expose a
        # renamed copy as the real input.
        features = {'x': tf.placeholder(tf.int32, [None, 4], name='x')}
        return tf.estimator.export.ServingInputReceiver(features, features)

    # checkpoint_path must be passed by keyword: the third positional
    # parameter of Estimator.export_saved_model is `assets_extra`.
    model.export_saved_model(model_dir, serving_input_fn,
                             checkpoint_path=checkpoint_path)
    
  • 使用 Session 冻结(变量转常量)并序列化计算图

    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    
    def freeze_graph(path='model.ckpt', output='model.pb', output_node_names=None):
        """Freeze a checkpoint into a self-contained binary GraphDef file.

        Imports the meta graph stored at ``path + '.meta'``, restores the
        variable values from ``path``, converts the variables needed by the
        output nodes into constants and writes the serialized GraphDef to
        ``output``.

        Args:
            path: checkpoint prefix; ``path + '.meta'`` must exist next to it.
            output: destination file for the frozen GraphDef.
            output_node_names: op names (no ``:0`` suffix) to keep in the
                frozen graph; defaults to ``['output/scores']``.

        Returns:
            The frozen ``GraphDef`` that was written to ``output``.
        """
        if output_node_names is None:
            output_node_names = ['output/scores']

        # Import into a fresh graph: importing into the shared default graph
        # makes a second call in the same process import a uniquified copy
        # and then freeze the first copy's (unrestored) variables.
        graph = tf.Graph()
        with graph.as_default():
            # clear_devices drops device placements recorded at training time.
            saver = tf.train.import_meta_graph(path + '.meta', clear_devices=True)

        with tf.Session(graph=graph) as sess:
            saver.restore(sess, path)
            output_graph_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=graph.as_graph_def(),
                output_node_names=list(output_node_names))

        with tf.gfile.GFile(output, 'wb') as fgraph:
            fgraph.write(output_graph_def.SerializeToString())
        return output_graph_def

    freeze_graph('model.ckpt', 'model.pb')
    
  • 测试模型

    def model(path='./logs/pb/model.pb'):
        """Load a frozen .pb graph and open a session on it.

        The GraphDef is imported with no name prefix into a new graph, so the
        tensors keep their original names.

        Args:
            path: path to the binary GraphDef file.

        Returns:
            ``(sess, input_x, keep_prob, pred)``: a new ``tf.Session`` bound to
            the imported graph (the caller owns it and should close it), plus
            the tensors 'input_x:0', 'dropout_keep_prob:0' and
            'output/scores:0'.
        """
        with tf.gfile.GFile(path, 'rb') as pb_file:
            serialized = pb_file.read()
        frozen_def = tf.GraphDef()
        frozen_def.ParseFromString(serialized)

        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(frozen_def, name='')

        input_x, keep_prob, pred = (
            graph.get_tensor_by_name(tensor_name)
            for tensor_name in ('input_x:0', 'dropout_keep_prob:0', 'output/scores:0'))
        return tf.Session(graph=graph), input_x, keep_prob, pred
    
    def predict(x, sess, input_x, keep_prob, pred, topk=1):
        """Evaluate ``pred`` on the batch ``x``.

        Args:
            x: value fed to ``input_x``.
            sess: session whose graph contains the given tensors.
            input_x: input placeholder tensor.
            keep_prob: dropout keep-probability placeholder; always fed 1.0.
            pred: output tensor to fetch.
            topk: accepted but currently unused; the full ``pred`` value is
                returned.

        Returns:
            Whatever ``sess.run(pred, ...)`` returns.
        """
        # Feeding 1.0 as the keep probability (presumably) disables dropout
        # for inference.
        return sess.run(pred, feed_dict={input_x: x, keep_prob: 1.0})
    
  • 加载 pb 模型

    def load_pb(path='model.pb'):
        """Read a binary GraphDef (.pb) file and return the parsed proto.

        Args:
            path: path to the serialized GraphDef.

        Returns:
            The ``tf.GraphDef`` parsed from the file.
        """
        with tf.gfile.GFile(path, 'rb') as pb_file:
            raw_bytes = pb_file.read()
        parsed_def = tf.GraphDef()
        parsed_def.ParseFromString(raw_bytes)
        return parsed_def
    
  • 将两个图合并

    def combined_graph(pb_a='./logs/pb/model.pb',
                       pb_b='./logs/pb/model_rnn.pb',
                       output_dir='./logs/pb/',
                       output_name='model_combine.pb',
                       output_node_names=None):
        """Merge two .pb graphs into one graph and write it to disk.

        Both GraphDefs are imported into a single new graph. The result is
        pruned to the nodes needed by ``output_node_names`` via
        ``convert_variables_to_constants`` and then written as a binary
        GraphDef.

        Args:
            pb_a: path of the first .pb file.
            pb_b: path of the second .pb file.
            output_dir: directory to write the merged graph into.
            output_name: file name of the merged graph.
            output_node_names: op names to keep; defaults to
                ``['output/scores_rnn', 'output/scores']``.

        Returns:
            The merged ``GraphDef`` that was written.
        """
        if output_node_names is None:
            output_node_names = ['output/scores_rnn', 'output/scores']

        # Parsing the files needs no session; do it up front.
        graph_a = load_pb(pb_a)
        graph_b = load_pb(pb_b)

        with tf.Graph().as_default() as g_combine:
            # name='' keeps the original node names so the outputs remain
            # addressable as 'output/scores' and 'output/scores_rnn'.
            # NOTE(review): node names present in both graphs will not stay
            # distinct under name='' — confirm the two models use disjoint
            # names apart from any intentionally shared inputs.
            tf.import_graph_def(graph_a, name='')
            tf.import_graph_def(graph_b, name='')

            with tf.Session(graph=g_combine) as sess:
                # The inputs are presumably already frozen (output of
                # freeze_graph), so this mainly prunes the merged graph to
                # the requested outputs.
                g_combine_def = graph_util.convert_variables_to_constants(
                    sess=sess,
                    input_graph_def=sess.graph_def,
                    output_node_names=list(output_node_names))

        tf.train.write_graph(g_combine_def, output_dir, output_name, as_text=False)
        return g_combine_def
    
  • 读取 ckpt 模型权重

    def read_ckpt(ckpt):
        """Load every variable stored in a checkpoint into a dict.

        Args:
            ckpt: checkpoint path/prefix accepted by
                ``tf.train.NewCheckpointReader``.

        Returns:
            dict mapping each variable name in the checkpoint to the value
            returned by ``reader.get_tensor`` (presumably a numpy array).
        """
        reader = tf.train.NewCheckpointReader(ckpt)
        # Only the variable names are needed; iterate the map's keys rather
        # than unpacking (name, shape) pairs and discarding the shape.
        return {name: reader.get_tensor(name)
                for name in reader.get_variable_to_shape_map()}

    weights = read_ckpt('model.ckpt')
    
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。