tensorflow 训练框架写法

tensorflow_framework

图片来源
对于训练数据和算法定义这块基本了解,但是对于算法训练这块,总觉得自己写的很奇怪,今天决定总结一下别人怎么写的,一点一点慢慢改善。


  1. tensorflow mnist tutorial
    这个教程感觉和之前看到的已经不一样了,tensorflow要大力推广一下Estimator和Dataset的框架,所以这个写法如下:
    在模型定义的最后:
### predictions是字典,包含输出的类别和概率
if mode == tf.estimator.ModeKeys.PREDICT:
  return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

### loss是labels与logits的交叉熵
if mode == tf.estimator.ModeKeys.TRAIN:
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-3)
  train_op = optimizer.monimize(loss=loss,global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

### 如果不是以上两种模式,则当做EVAL处理
eval_metric_ops = {
  "accuracy": tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])}
return tf.estimator.EstimatorSpec(mode=mode,loss=loss,eval_metric_ops=eval_metric_ops)

我对上述代码有个问题:这一段是写在 def cnn_model_fn(features, labels, mode)中的,如果是PREDICT模式,那么没有办法提供labels怎么办?

然后在主函数中:

mnist_classifier = tf.estimator.Estimator(
  model_fn=cnn_model_fn, model_dir="/tmp/mnist_convet_model")

train_input_fn = tf.estimator.inputs.numpy_input_fn(
  x = {"x": train_data},
  y = train_labels,
  batch_size = 100,
  num_epochs = None,
  shuffle = True)

mnist_classifier.train(input_fn = train_input_fn, steps = 20000, hooks=[logging_hook])

eval_input_fn = tf.estimator.inputs.numpy_input_fn(
  x = {"x": eval_data},
  y = eval_labels,
  num_epochs = 1,
  shuffle = False)

eval_results = mnist_classifier.evaluate(input_fn=eval_input_fn)

所以mode参数不是显式给定的,应该是通过调用estimator的不同方法而隐式确定,所以tensorflow应该有内部的方法去处理没有labels的问题,可能直接赋值0就可以了。

  1. 极客学院MNIST
    这个似乎是我之前看到过的版本,确实比较简明。
    其实必要条件也这么多,定义好train_op以及一系列metrics,在循环中得到batch input,然后训练,一定间隔后输出loss和metrics信息。
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
sess.run(tf.initialize_all_variables())
for i in range(20000):
  batch = mnist.train.next_batch(50)
  if i%100 == 0 :
    train_accuracy = accuracy.eval(feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0)
    print "setp %d, training accuracy %g"%(i, train_accuray)
  train_step.run(feed_dict={x:batch[0],y_=batch[1]},keep_prob:0.5)
  1. FCN_tensorflow in github
sess = tf.Session()

print("Setting up Saver...")
saver = tf.train.Saver()
summary_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)

sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
if ckpt and ckpt.model_checkpoint_path:
  saver.restore(sess, ckpt.model_checkpoint_path)
  print("Model restored...")

if FLAGS.mode == "train":
  for itr in xrange(MAX_ITERATION):
    train_images, train_annotations = train_dataset_reader.next_batch(FLAGS.batch_size)
    feed_dict = {image:train_images, annotation:train_annotations, keep_probability: 0.85}

    sess.run(train_op, feed_dict = feed_dict)

    if itr % 10 ==0:
      train_loss, summary_str = sess.run([loss,summary_op], feed_dict=feed_dict)
      print("Step: %d, Train_loss: %g" % (itr, train_loss))
      summary_writer.add_summary(summary_str, itr)

    if itr % 500 ==0:
      valid_images, valid_annotations = validation_dataset_reader.next_batch(FLAGS.batch_size)
      valid_loss = sess.run(loss, feed_dict = {image:valid_images, annotation:valid_annotations, keep_probability: 1.0})
      print("%s --> Validation_loss: %g" % (datetime.datatime.now(), valid_loss))
      saver.save(sess, FLAGS.logs_dir + "model.ckpt", itr)

这一个稍微复杂一些,但是目前看来,函数式的训练方法大体都是这样。

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 194,390评论 5 459
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 81,821评论 2 371
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 141,632评论 0 319
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 52,170评论 1 263
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 61,033评论 4 355
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 46,098评论 1 272
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 36,511评论 3 381
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 35,204评论 0 253
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 39,479评论 1 290
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 34,572评论 2 309
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 36,341评论 1 326
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 32,213评论 3 312
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 37,576评论 3 298
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 28,893评论 0 17
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 30,171评论 1 250
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 41,486评论 2 341
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 40,676评论 2 335

推荐阅读更多精彩内容