TensorFlow学习笔记(14)使用SignatureDef保存和恢复RNN模型

环境:
Python 3.5.2
tensorflow : 1.11.0
ubuntu : 16.04

保存模型,github代码

  saved_model_dir='./model'                                                                                                                                                                                                                                     
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)                                                                                                                                                                                           
                                                                                                                                                                                                                                                                
  # input_x, keep_prob                                                                                                                                                                                                                                          
  inputs = {'input_x': tf.saved_model.utils.build_tensor_info(xs),                                                                                                                                                                                              
           'input_y': tf.saved_model.utils.build_tensor_info(ys),                                                                                                                                                                                              
            'keep_prob': tf.saved_model.utils.build_tensor_info(keep_prob)}                                                                                                                                                                                     
                                                                                                                                                                                                                                                                
  # prediction 为预测函数,恢复的时候要通过该函数来预测                                                                                                                                                                                                         
  outputs = {'prediction' : tf.saved_model.utils.build_tensor_info(prediction)}                                                                                                                                                                                 
                                                                                                                                                                                                                                                                
  signature = tf.saved_model.signature_def_utils.build_signature_def(inputs, outputs, 'test_sig_name')                                                                                                                                                          
                                                                                                                                                                                                                                                                
  with tf.Session() as sess:                                                                                                                                                                                                                                    
      sess.run(init)                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                
      for i in range(1000):                                                                                                                                                                                                                                     
          batch_xs, batch_ys = mnist.train.next_batch(100)                                                                                                                                                                                                      
          sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})                                                                                                                                                                          
          if i % 50 == 0:                                                                                                                                                                                                                                       
              print(compute_accuracy(sess, prediction,                                                                                                                                                                                                          
                  mnist.test.images[:1000], mnist.test.labels[:1000]))                                                                                                                                                                                          
                                                                                                                                                                                                                                                                
      builder.add_meta_graph_and_variables(sess, ['model_final'], {'test_signature':signature})                                                                                                                                                                 
      builder.save()   

恢复模型github代码

  saved_model_dir='./model'                                                                                                                                                                                                                                     
                                                                                                                                                                                                                                                                
  signature_key = 'test_signature'                                                                                                                                                                                                                              
  input_key_x = 'input_x'                                                                                                                                                                                                                                       
  input_key_y = 'input_y'                                                                                                                                                                                                                                       
  input_key_keep_prob = 'keep_prob'                                                                                                                                                                                                                             
  output_key_prediction = 'prediction'                                                                                                                                                                                                                          

  with tf.Session() as sess:                                                                                                                                                                                                                                    
      meta_graph_def = tf.saved_model.loader.load(sess, ['model_final'], saved_model_dir)                                                                                                                                                                       
                                                                                                                                                                                                                                                                
      # 从meta_graph_def中取出SignatureDef对象                                                                                                                                                                                                                  
      signature = meta_graph_def.signature_def                                                                                                                                                                                                                  
                                                                                                                                                                                                                                                                
      # 从signature中找出具体输入输出的tensor name.                                                                                                                                                                                                             
      x_tensor_name = signature[signature_key].inputs[input_key_x].name                                                                                                                                                                                         
      y_tensor_name = signature[signature_key].inputs[input_key_y].name                                                                                                                                                                                         
      keep_prob_tensor_name = signature[signature_key].inputs[input_key_keep_prob].name                                                                                                                                                                         
      prediction_tensor_name = signature[signature_key].outputs[output_key_prediction].name                                                                                                                                                                     
                                                                                                                                                                                                                                                                
      # 获取tensor 并inference                                                                                                                                                                                                                                  
      input_x = sess.graph.get_tensor_by_name(x_tensor_name)                                                                                                                                                                                                    
      input_y = sess.graph.get_tensor_by_name(y_tensor_name)                                                                                                                                                                                                    
      keep_prob = sess.graph.get_tensor_by_name(keep_prob_tensor_name)                                                                                                                                                                                          
      prediction = sess.graph.get_tensor_by_name(prediction_tensor_name)                                                                                                                                                                                        

通过恢复的模型,来预测结果

                                                                                                                                                                                                                                                                
      # 测试单个数据                                                                                                                                                                                                                                            
      x = mnist.test.images[index].reshape(1, 784)                                                                                                                                                                                                              
      y = mnist.test.labels[index].reshape(1, 10)  # 转为one-hot形式                                                                                                                                                                                            
      print (y)                                                                                                                                                                                                                                                 
                                                                                                                                                                                                                                                                
      pred_y = sess.run(prediction, feed_dict={input_x: x, keep_prob : 1 })                                                                                                                                                                                     
      print (pred_y)                                                                                                                                                                                                                                            
                                                                                                                                                                                                                                                                
      print("Actual class: ", str(sess.run(tf.argmax(y, 1))), \                                                                                                                                                                                                 
            ", predict class ",str(sess.run(tf.argmax(pred_y, 1))), \                                                                                                                                                                                           
            ", predict ", str(sess.run(tf.equal(tf.argmax(y, 1), tf.argmax(pred_y, 1))))                                                                                                                                                                        
            )                                                                                                                                                                                                                                                   
                                                                                                                                                                                                                                                                
      # 测试数据集                                                                                                                                                                                                                                              
      print(compute_accuracy(sess, prediction, input_x, keep_prob,                                                                                                                                                                                              
        mnist.test.images[:1000], mnist.test.labels[:1000]))     
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,753评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,668评论 3 396
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 166,090评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 59,010评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 68,054评论 6 395
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,806评论 1 308
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,484评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,380评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,873评论 1 319
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 38,021评论 3 338
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,158评论 1 352
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,838评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,499评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,044评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,159评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,449评论 3 374
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,136评论 2 356

推荐阅读更多精彩内容

  • Android 自定义View的各种姿势1 Activity的显示之ViewRootImpl详解 Activity...
    passiontim阅读 172,207评论 25 707
  • 2018“书香湖南·阅行者”暨“我的书屋我的梦”暑期少儿阅读实践活动——阳光行动第四场:江娅镇龙台村,阅读推广人正...
    玄矶阅读 1,436评论 0 3
  • 感赏儿子早上起床比较迅速。 感赏儿子大雪中走回家,袜子鞋子都湿了。 感赏儿子通过看书听课学习国内课程。 感赏老公昨...
    燕子重生scy阅读 147评论 2 2
  • 我想搭上一班地铁,虽然不知道会开向何方,我只在意沿途会有鲜花盛开,会有鸟语花香,我只在意停靠的一个小站,就在车门轻...
    漾儿阅读 163评论 0 0