Learning Tensorflow part 2

说明一些tensorflow课程中的用法.

1. 加载数据

def normalize(images, labels):
  images = tf.cast(images, tf.float32)
  images /= 255
  return images, labels

train_dataset =  train_dataset.map(normalize)
train_dataset =  train_dataset.cache()

normalize将图片数据由[0, 255]归一化到[0, 1].

  • cast将tensorflow的数据改为另一种数据类型, 这里只是改为浮点型, 避免除操作使其全为0.
  • map将数据进行重新映射tf.data.Dataset.map, 可以参考官方文件
  • cache 要求数据迭代之后(包括map)需要进行cache, 否则下一次迭代不会使用已经cache的数据.

2. 构造神经网络

l0 = tf.keras.layers.Flatten(input_shape = (28, 28, 1))
  • Flatten为展平输入.
  • 这里是灰度图, 因此是[28, 28, 1].

3. 构造误差函数

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])
  • crossentropy loss 交叉熵损失.

4. 训练

BATCH_SIZE = 32
train_dataset = train_dataset.cache().repeat().shuffle(num_train_examples).batch(BATCH_SIZE)
test_dataset = test_dataset.cache().batch(BATCH_SIZE)

model.fit(train_dataset, epochs=5, steps_per_epoch=math.ceil(num_train_examples/BATCH_SIZE))

这里要打乱原有的数据

  • repeat()表示重复使用
  • shuffle为随机打乱次序
  • batch(32)表示每次迭代32个数训练.
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。