TensorFlow入门4 -- 鸢尾(Iris)花分类,导入并解析训练数据集

参考:《深度学习图像识别技术--基于TensorFlow Object Detection API 和 OpenVINO

问题:假设你的生物学家,要对鸢尾(Iris)花分类。Iris有300多类,这里仅仅对Iris setosa,Iris virginica,Iris versicolor 这三类进行识别,如下图所示

本文范例程序下载地址:IrisClassifier.py

方法有很多种,比如,基于CNN的深度学习,直接学习图像。这里采用更加简单的方法,通过  sepals(花萼)和 petals(花瓣)的长度和宽度数据,进行模型训练和分类,这样更加适合初学者。

收集和构架数据集要花很多时间,幸运的是,已经有现成的Iris flower data set,which contains a set of 150 records under 5 attributes - Petal Length , Petal Width , Sepal Length , Sepal width and Class 如下图所示


基于这样的数据集(DataSet),可以让我们更加专注于学习机器学习的算法,而不需要花大量时间准备数据

第一步:下载训练数据集

我们需要把dataset文件下载到本地,然后把它转化为Python可以使用的数据结构。范例代码如下:


打开文件:C:\Users\tf\.keras\datasets\iris_training.csv


可以看到有120行数据,跟Iris data set wiki里面说的不大一样,不过没有关系,不影响训练。

前四列是Features,分别是:Petal Length , Petal Width , Sepal Length , Sepal width

第五列是label,分别用整型数来代表花的种类,对机器来说,用整型数比用字符串更加方便,但我们要知道整型数和花种类之间的映射:

0: Iris setosa

1: Iris versicolor

2: Iris virginica

第二步:解析(Parse)数据集

下载到本地的数据集iris_training.csv 是一个 CSV格式的文本文件, TensorFlow模型还不能直接使用。我们需要把feature和label的值按照TensorFlow模型的数据输入要求,重新格式化。

创建一个函数 parse_csv

输入参数是:iris_training.csv文件的一行(line),

功能是:把 前四个 feature 值合并成为一个List,并reshape成为一个 single tensor;把最后一个 label 变量reshape成为一个single tensor.

返回值: features 和 label tensors

如下所示:


tf.decode_csv函数功能是:Convert CSV records to tensors. Each column maps to one tensor.

tf.reshape(tensor, shape,name=None)函数的功能是:Given tensor, this operation returns a tensor that has the same values as tensor with shape shape

第三步:创建训练 tf.data.Dataset

TensorFlow's Dataset API 用于feeding data into a model,它负责读取data,并将data转换为适合模型训练的格式

代码如下所示:

执行结果如下所示:

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容