Wide and Deep应用
-- based on Google Analytics Customer Revenue Prediction
最近发现一个回归任务,目标是预测谷歌商店的收入(链接),数据是用户在商店中的浏览数据,利用传统的GBDT方式可以做出一定的预测,baseline大约在1.4285(RSME),关于数据的处理和GBDT的训练在此不在赘述,仅记录下利用tensorflow训练wide and deep深度模型的结果。
wide and deep 原理
deep
请大家注意,在deep层,所有的输入都要是数字,所以类别列中的函数(蓝色),在输入层也要经过数值列的函数(红色)包装才行。
tensorflow输入构建
输入的特征分为两大类,类别特征和数值特征,针对这两种特征,神经网络都要将特征转化为数字输入,但是方法不太一样。
tf.feature_column是针对特征输入构建的包,里面有9类函数,他们都会返回Categorical Column或者Dense Column对象,关系图如下:
下面分别对数值列和类别列做介绍。
数值列
tf.numeric_column:数值输入列,表示输入的列为数值,指定key即可,也可以利用shape参数指定向量。
Bucketized Column
分桶,对于年龄、体重这种字段,如果不希望直接将数字提供给模型,可以将数值分桶,分桶后数据会形成onehot的结果
column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
categorical_column_with_identity
分类标识列 (Categorical Identity Column),通过索引输入的分类列,输入的值都是类别的id(从[0,类别数)),通过tf.feature_column.categorical_column_with_identity()可以将其映射成onehot编码
categorical_column_with_vocabulary_list
通过给定词汇表,将类别特征列转换为index
tensorflow的例子
tensorflow在其官方github的models中给出了一个完整的deep and wide代码,我仔细的看了看这个代码,然后将它修改简化,变成了我运行的脚本。只能说,看代码容易,写代码难。