import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
def preprocess(x,y):
x=tf.cast(x, dtype=tf.float32)/255.
y=tf.cast(y, dtype=tf.int32)
return x,y
# 加载数据集
(x,y),(x_test,y_test)=datasets.fashion_mnist.load_data()
print(x.shape,y.shape)
batchsz=128
# 数据集划分成batch
db=tf.data.Dataset.from_tensor_slices((x,y))
db=db.map(preprocess).shuffle(10000).batch(batchsz)
db_test=tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_test=db_test.map(preprocess).batch(batchsz)
# 测试查看
db_iter=iter(db)
sample=next(db_iter)
print('batch:',sample[0].shape,sample[1].shape)
# 构建前馈网络
model=Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(64,activation=tf.nn.relu),
layers.Dense(32,activation=tf.nn.relu),
layers.Dense(10),
])
# build
model.build(input_shape=[None,28*28])
model.summary()
optimizer=optimizers.Adam(lr=1e-3)
total_correct=0
total_num=0
# 训练
for epoch in range(30):
for step ,(x,y) in enumerate(db):
x=tf.reshape(x,[-1,28*28])
with tf.GradientTape() as tape:
logits=model(x)
y_onehot=tf.one_hot(y,depth=10)
loss_mse=tf.reduce_mean(tf.losses.MSE(y_onehot,logits))
loss_ce=tf.losses.categorical_crossentropy(y_onehot,logits,from_logits=True)
loss_ce=tf.reduce_mean(loss_ce)
grads=tape.gradient(loss_ce,model.trainable_variables)
# 用定义好的optimizer对trainable_variables更新
optimizer.apply_gradients(zip(grads,model.trainable_variables))
if step%100==0:
print(epoch,step,'loss:',float(loss_ce),float(loss_mse))
# test
for x,y in db_test:
# x:[b,28,28]=>[b,784]
# y:[b]
x=tf.reshape(x,[-1,28*28])
# [b,10]
logits=model(x)
# logits=>prob,[b,10]
prob=tf.nn.softmax(logits,axis=1)
# [b,10]=>[b]
pred=tf.argmax(prob,axis=1)
# pred.dtype是int64
pred=tf.cast(pred,dtype=tf.int32)
# pred:[b]
# y:[b]
# correct:[b]
correct=tf.equal(pred,y)
correct=tf.reduce_sum(tf.cast(correct,dtype=tf.int32))
total_correct+=int(correct)
total_num += x.shape[0]
acc=total_correct/total_num
print(epoch,'test acc:',acc)
6 fashion-mnist实战
最后编辑于 :
©著作权归作者所有,转载或内容合作请联系作者
- 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
- 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
- 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
推荐阅读更多精彩内容
- 主角多会嘴遁,反派死于话多。 扯了三篇咱们直接开始上手一个项目体验下机器学习的魅力,相关的理论知识,遇到了咱们再拎...
- 上一讲笔者和大家一起学习了如何使用 Tensorflow 构建一个卷积神经网络模型。本节我们将继续利用 Tens...
- Tensorflow之基于MNIST手写识别的入门介绍 TensorFlow实战:SoftMax手写体MNIST识...
- import numpy as np import matplotlib as plt import pandas...