tensorflow CNN图像分类中的数据shape变化

https://zhuanlan.zhihu.com/p/27288913的基础上,重写了tf.Graph。

    global_step = tf.Variable(0, trainable=False)
    # placeholder
    images = tf.placeholder(tf.float32, [BATCH_SIZE, 32, 32, 3], name='images')
    labels = tf.placeholder(tf.int32, (BATCH_SIZE,), name='labels')

    print("Done Initializing Training Placeholders")

labels不是one-hot模式,就是数字本身。
placeholder的第一维都是固定的batch_size。

    # Build a Graph that computes the logits predictions from the placeholder
    logits = CNN(images)

    # Calculate loss
    loss = cal_loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

logits的shape是(batch_size,10),是one-hot形式
cal_loss中,Logits的shape是(batch_size,10),而labels则是(batch_size,1),因此用的函数是tf.nn.sparse_softmax_cross_entropy_with_logits

训练部分:

    for step in range(1000):
        # Current batch number
        batch_nb = step % nb_batches

        # Current batch start and end indices
        start, end = utils.batch_indices(batch_nb, data_length, BATCH_SIZE)

        # Prepare dictionnary to feed the session with
        feed_dict = {images: X_train[start:end],
                     labels: y_train[start:end]}

        # Run training step
        _, loss_value = sess.run([train_step, loss], feed_dict=feed_dict)

        # Echo loss once in a while
        if step % 20 == 0:
            num_examples_per_step = BATCH_SIZE
            examples_per_sec = num_examples_per_step / duration
            sec_per_batch = float(duration)

            format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                          'sec/batch)')
            print(format_str % (datetime.now(), step, loss_value,
                                examples_per_sec, sec_per_batch))

检测部分:

newbatch = math.ceil(1000 / BATCH_SIZE)
preds = np.zeros((1000, NUM_CLASS), dtype=np.float32)
# 检测数据有1000,分为64大小的部分循环检测
for cnt in range(0, int(newbatch + 1)):
      # Compute batch start and end indices
      start, end = utils.batch_indices(cnt, 1000, BATCH_SIZE)
      # Prepare feed dictionary
      feed_dict = {images: X_test[start:end]}
      preds[start:end, :] = sess.run([logits], feed_dict=feed_dict)[0]#取第一维

precision = accuracy(preds, y_test)
print('Precision of teacher after training: ' + str(precision))

训练步长设置为0.1,正确率达到60%
训练步长设置为0.05,正确率达到65%
链接:https://github.com/yingtaomj/cnn-classification

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容