机器学习 Logistic Regression 识别手写数字

原理

    原理跟我上一篇文章讲述得其实差不多,不过从原来的二分类问题拓展到多分类问题。在此次实验当中,我建立了十个分类器,分别用于识别数字0~9。关于计算机为什么能够从矩阵上面,通过对各个像素点的权重进行简单线性计算就能实现分类,这个问题我至今还没有想明白。查阅百度,也没能得到很好的答案,只能大致得理解为不同的数字有不同的特征,通过对特征的识别完成分类。

过程

    首先建立代价函数,代价函数和上一篇文章一样,只需要注意样本和我们所求θ^n 的维度即可。

    在上一篇文章中,我们是自己手动实现梯度下降,去求得代价函数的最小值。而在这次,我们使用一个机器学习常用函数minimize(),输入代价函数,θ^n,和梯度下降函数(其实就是代价函数求偏导),再选择下降模式就可以了,所以过程很简单。

效果


   我们用训练集训练我们的模型,然后将训练集当作测试集来进行测试。目前准确率只能到达94.5%左右,因为Logistic Regression无法处理复杂的非线性问题,所以准确率还不是很理想。不过,就操作难易度来讲,总体还是很不错的。毕竟有些训练集简直是魔鬼,人也不能分辨出来好嘛!!!看看这随机抽取的100个样本!


具体实现

需要导入的库


数据加载


决策函数


代价函数

        

梯度下降函数


训练函数


预测函数


主函数


结语

    做完这个实验,感觉机器学习好像开始慢慢有作用了。我们已经能够运用Logistic Regression 来解决简单的实际问题。虽然这个我描述的很简单,但是里面数据处理需要很大的耐心,不然报bug会让人很心累。    

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容