基于svm的数字识别
本代码是基于opencv3的,所以编译有问题请换环境。
#include "opencv2/opencv.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/ml.hpp"
#include <iostream>
using namespace cv;
using namespace std;
using namespace cv::ml;
int main(int argc, char** argv)
{
// initial SVM
cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();
svm->setType(cv::ml::SVM::Types::C_SVC);
svm->setKernel(cv::ml::SVM::KernelTypes::LINEAR);
svm->setTermCriteria(cv::TermCriteria(cv::TermCriteria::MAX_ITER, 100, 1e-6));
//生产数据
//训练字符每类数量
const int sample_mun_perclass = 80;
//训练字符类数
const int class_mun = 10;
const int image_cols = 16;
const int image_rows = 16;
cv::Mat sample;
cv::Mat label(cv::Size(0, 0), CV_32SC1); // 注意这里和2.x区别,3.x必须使用CV_32S,原因后面会说
for(int i =0;i<class_mun;i++){
for(int j=0;j<sample_mun_perclass;j++){
stringstream filename;
filename<<"D:/workspace/qt/opencvPractice/svmDigitRecognition/num/"<<i<<"/"<<j<<".jpg";
cv::Mat &&img = imread(filename.str(), cv::IMREAD_GRAYSCALE);
if (img.dims > 0) {
// 归一化, 一般认为能加快参数优化时的收敛速度、平衡模型权重等
resize(img,img,Size(image_cols,image_rows),(0,0),(0,0),CV_INTER_AREA);
cv::normalize(img, img, 1.0, 0.0, cv::NormTypes::NORM_MINMAX, CV_32FC1);
// 但如果你的数据量级本身相差不大,也可以不归一化直接convertTo即可
// img.convertTo(img, CV_32FC1);
sample.push_back(img.reshape(0, 1));
label.push_back<int>(i); // 注意push_back有模版重载,可能意外改变Mat类型
} //if
}
}
cout<<"start train ......"<<endl;
cv::Ptr<cv::ml::TrainData> &&trainDataSet = cv::ml::TrainData::create(sample, cv::ml::ROW_SAMPLE, label);
svm->trainAuto(trainDataSet);
svm->save("D:/workspace/qt/opencvPractice/svmDigitRecognition/svm.xml");
cout<<"end"<<endl;
int x =0,y = 0;
for(int num=0;num<10;num++){
for(int i =0;i<100;i++){
stringstream ss;
ss<<"D:/workspace/qt/opencvPractice/svmDigitRecognition/num/"<<num<<"/"<<i<<".jpg";
cv::Mat &&tk = cv::imread(ss.str(), cv::IMREAD_GRAYSCALE);
resize(tk,tk,Size(image_cols,image_rows),(0,0),(0,0),CV_INTER_AREA);
cv::normalize(tk, tk, 1., 0., cv::NormTypes::NORM_MINMAX, CV_32FC1);
// tk.convertTo(tk, CV_32FC1);
float r = svm->predict(tk.reshape(0, 1));
//cout<<"result:"<<r<<endl;
if(((int)r) == num){
x++;
}else{
y++;
}
}
}
cout<<"x="<<x<<" ,y="<<y<<" Accuracy rate:"<<(float)x/(x+y)<<endl;
cv::waitKey(0);
return 0;
}
代码中用到的图片资源,请在百度云上下载。
点我下载 密码:du5a