Wide and Deep应用

Wide and Deep应用

-- based on Google Analytics Customer Revenue Prediction
最近发现一个回归任务,目标是预测谷歌商店的收入(链接),数据是用户在商店中的浏览数据,利用传统的GBDT方式可以做出一定的预测,baseline大约在1.4285(RSME),关于数据的处理和GBDT的训练在此不在赘述,仅记录下利用tensorflow训练wide and deep深度模型的结果。

wide and deep 原理

deep

请大家注意,在deep层,所有的输入都要是数字,所以类别列中的函数(蓝色),在输入层也要经过数值列的函数(红色)包装才行。

tensorflow输入构建

输入的特征分为两大类,类别特征和数值特征,针对这两种特征,神经网络都要将特征转化为数字输入,但是方法不太一样。
tf.feature_column是针对特征输入构建的包,里面有9类函数,他们都会返回Categorical Column或者Dense Column对象,关系图如下:

1540287641170.png

下面分别对数值列和类别列做介绍。

数值列

tf.numeric_column:数值输入列,表示输入的列为数值,指定key即可,也可以利用shape参数指定向量。


1540288164410.png

Bucketized Column

分桶,对于年龄、体重这种字段,如果不希望直接将数字提供给模型,可以将数值分桶,分桶后数据会形成onehot的结果

column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])

categorical_column_with_identity

分类标识列 (Categorical Identity Column),通过索引输入的分类列,输入的值都是类别的id(从[0,类别数)),通过tf.feature_column.categorical_column_with_identity()可以将其映射成onehot编码

categorical_column_with_vocabulary_list

通过给定词汇表,将类别特征列转换为index

tensorflow的例子

tensorflow在其官方github的models中给出了一个完整的deep and wide代码,我仔细的看了看这个代码,然后将它修改简化,变成了我运行的脚本。只能说,看代码容易,写代码难。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容