预测实现
在上一篇文章中实现了装甲板id识别的网络训练并保存为了ckpt文件
https://www.jianshu.com/p/191337a9a819
虽然全连接的网络精度也就那样了,但是还是练习一下用现有的网络进行装甲板id预测
- 复现网络
#网络搭建
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(500,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(128,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(50,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(8,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())
])
- 加载参数
#加载参数
ckpt_path = "./checkpoint/armor_id.ckpt"
if(os.path.exists(ckpt_path + ".index")):
print("--load modle--")
model.load_weights(ckpt_path)
else:
print('----------------------------------------------error')
- 输入数据处理
#图片读取与处理
img = tf.io.read_file (test_img_path)
img_raw = tf.image.decode_bmp (img)
img_raw = tf.cast(img_raw,dtype=tf.float32)
x_predict = tf.convert_to_tensor(img_raw)
x_predict = tf.reshape(x_predict,[1,-1])
- 代码整体实现
import tensorflow as tf
import os
if __name__ == '__main__':
test_img_path = './armor_dataset/8/8_47.bmp'
#网络搭建
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(500,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(128,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(50,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
tf.keras.layers.Dense(8,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())
])
#加载参数
ckpt_path = "./checkpoint/armor_id.ckpt"
if(os.path.exists(ckpt_path + ".index")):
print("--load modle--")
model.load_weights(ckpt_path)
else:
print('----------------------------------------------error')
#图片读取与处理
img = tf.io.read_file (test_img_path)
img_raw = tf.image.decode_bmp (img)
img_raw = tf.cast(img_raw,dtype=tf.float32)
x_predict = tf.convert_to_tensor(img_raw)
x_predict = tf.reshape(x_predict,[1,-1])
#预测结果
result = model.predict(x_predict)
pred = tf.argmax(result,axis=1) #获取概率最大数值的下标
pred = pred + 1
print("预测id为:")
tf.print(pred)
遇到的坑
- 实际未读入数据
现象是每次输出结果随机变化 - 使用tfrecord解码的数据和使用原始数据解码的数据不一致
应当检查编码解码过程中的类型转换
https://www.jianshu.com/p/51659ec687f8
测试结果
测试图片
51.png
5177.png
847.png
还进行了其他数字的测试
测试图片基本都实现了正确的预测