继续上一篇中的问题2,怎样对自定义的图片数据集进行训练呢?在参数和模型固定的情况下,增加训练集,有助于提高模型的泛化能力。在前面,我们是对下载下来的mnist数据集进行的训练,那么怎样应用到我们自己制作或采集的图片上呢?
如上面的图片,我们需要借助tfrecords文件,这是一种能将图像数据和标签放在一起的二进制文件,可以提高内存效率,实现快速的读取、存储等。
首先,新建一个writer用于制作tfreords文件,num_pic用于计数,f 打开的文件是如下的txt文件,每行是图片文件名和对应的数字,用for循环遍历每张图片。
这样,value[0]就是图片名,用于读图片,value[1]是对应标签,用于生成一行10列的一维数组,对应标签索引的值是1。
tf.train.Example包含Features字段,features里包含feature字典,键值对是img_raw和label对应的图像数据和标签数据,格式是bytelist和Int64list,这样就把所有信息存到一个example文件里,方便读取。
对于tfrecords文件的解析如下:
tf.train.string_input_producer生成一个先入先出的队列queue,用来读取数据,传入存储信息的tfrecords文件名列表,然后新建一个reader,把读取的每个样本保存到serialized_example中,通过tf.parse_single_example将协议内存块解析为张量,传入待解析的内存块,features字典映射,传入的字典键名要和制作的相同,返回1*784的图片张量和1*10的标签。
get_tfrecords是随机读取batch_size个样本,tf.train.shuffle_batch传入待乱序处理的tensors[img,label],返回batch_size组的img,label.
这样,我们通过tfrecords文件实现了对自定义图片数据集的读取,要注意的是还有修改反向传播和test文件中的接口。
在backward文件中要先导入生成tfrecord的文件,传入训练样本数,获得每轮喂入的batch,在会话中开启多线程协调器,提高图片和标签批获取效率。其中,tf.train.start_queue_runners()将会启动输入队列线程,填入训练样本到队列中,配合tf.train.Coordinator()在发生错误的情况下关闭线程。
同理,修改test文件,在获取batch时get_tfrecord()中的isTrain=False.
这样,我们就解决了问题2,运行backward文件训练数据,运行test文件观察准确度,达到一定准确度后可以结合上一篇中的appliaction文件对实际图片进行预测了。
需要注意的一点是,在train和test集上准确度符合要求,但实际预测有错误时,关注图片的预处理过程,如修改阈值等。
新手学习,欢迎指教!