参考:《深度学习图像识别技术--基于TensorFlow Object Detection API 和 OpenVINO》
接上节,本文主要解释构建特征列,方便后续的函数找到对应的特征数据。
注意:范例中的代码也是相当烧脑,待文本慢慢解析后,会用更清晰明了的方式实现。
首先,回忆一下《从数据的角度理解TensorFlow鸢尾花分类程序4》一节中,print出来的变量(train_x, train_y), (test_x, test_y) 。上节说了,特征值和特征标签都存储在这四个变量里面。以train_x为例,train_x是Dataframe类型的变量,如下图所示:
Dataframe类型类似一个表格,由行号、每列的键值和数值组成,要引用每列的值,可以用每列的键值名来引用。
可以看出train_x每列的键值是['SepalLength', 'SepalWidth','PetalLength', 'PetalWidth'],这些键值也是特征点的名称。把特征点的名称告诉模型,模型才能使用。特征列是一种数据结构,告知模型如何解读每个特征中的数据,不仅要告诉模型特征的名称,还要告诉模型特征的数据类型。
若特征列中的特征名字与train_x中的特征名字不一致,则模型无法找到特征数据。
综上所示,可以把范例里面的代码改写为:
这样更加清晰,容易懂。
my_feature_column的打印信息中:_NumericColumn(key='SepalLength', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None),_NumericColumn是一个类(class),小括号里面是该类的属性和方法。
注意:范例里面的写法也有好处,就是当特征种类很多的时候,比如100个,像本文这样手动指定特征列每列的key 和 类型,就非常不适合了,还是用范例中的写法,简洁。