github:https://github.com/wuzy361/mnist_homework_project
github上显示上次更新已经是三周前,这个项目搁置了很久了。
在第一次更新中,主要工作是把数据集的接口写好了,直接把用二进制文件保存的数据集转化成python能直接处理的ndarray(numpy中的数组)。
其实有了规定格式的数据集后,使用sklearn库,就能很方便的做一些机器学习的工作了。以下是探索数据进行机器学习的一些经验:
1,numpy.squeeze()
关于这个函数最早是在http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 里看到的,里面的变量虽然不是numpy类型的而是torch.FloatTensor类型的。但是二者非常类似,方法名也一样。numpy.squezze()是这样的:
作用是把一维从数组里移除,不如shape本来是(1,2,3),代表一个三维张量,但其实和一个二维矩阵内容是一样的。混用(1,2,3)和(2,3)可能导致程序产生警告,甚至是错误,所以应该使用squeeze处理矩阵。
该代码应该改成:
2,navie bayes 和svm
navie bayes 分类器应该是最简单的分类器了,试验结果是这样的:
对于60000个训练数据集和1000个测试数据集来说,naive bayes用时非常短,但准确度很低,只有55.58%,毕竟NAIVE啊。
svm就不一样了,早就直到svm非常慢,在这不算小的数据集下,svm太慢了,所以我就先用pca降维了。
把原来的28×28 = 784维的数据降成80维,之后再训练:
精确度很高,到达了98.18%,
跟周志华的论文的对比实验接近,稍微差一点的是由于使用了pca,不可避免有点数据损失。