手写数据识别
数据集任务类型
-
监督学习
分类
回归
-
无监督学习
聚类
密度估计
数据可视化
实现流程
1.模块引用
from sklearn.datasets import load_digits from sklearn
import svm # 支持向量机的模块
from sklearn.model_selection import train_test_split
from sklearn.externals
import joblib #模型保存文件对象
import numpy as np
import matplotlib.pyplot as plt
#魔法函数,用于matplotlib绘制的图显示在页面里,不需要plt.show()
%matplotlib inline
2.数据加载
digits = load_digits()
3.数据探查
#查看数据集的键名
digits.keys()
#查看手写字图片的形状
digits.images.shape
#相当于对images对象进行了降维操作
digits.data.shape
#查看手写字标签的形状
digits.target.shape
#查看第一张图片
plt.imshow(digits.images[0])
4.数据划分
#把数据集的数据和标签划分为训练集和测试集以及他们的标签
X_train,X_test,y_train,y_test= \
train_test_split(digits.data,digits.target,test_size=.25,random_state=42)
X_train.shape,X_test.shape,y_train.shape,y_test.shape
5.相关SVM模型(SVC)以及预测数据
C:默认值是1.0
C越大,即对误分类的惩罚增大,泛化能力弱,趋向于对训练集全分对的情况,这样对训练集测试时准确率很高。
C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。
kernel :核函数,默认是rbf,可以是‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’
线性:u'v
多项式:(gammau'v + coef0)^degree
RBF函数:exp(-gamma|u-v|^2)
random_state :数据洗牌时的种子值,int值
sigmoid:tanh(gammau'v + coef0)
degree :多项式poly函数的维度,默认是3,选择其他核函数时会被忽略。
gamma :‘rbf’,‘poly’ 和‘sigmoid’的核函数参数。默认是’auto’,则会选择1/n_features
coef0 :核函数的常数项。对于‘poly’和 ‘sigmoid’有用。
probability :是否采用概率估计?.默认为False
shrinking :是否采用shrinking heuristic方法,默认为true
tol :停止训练的误差值大小,默认为1e-3
cache_size :核函数cache缓存大小,默认为200
class_weight :类别的权重,字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)
verbose :允许冗余输出?
max_iter :最大迭代次数。-1为无限制。
decision_function_shape :‘ovo’, ‘ovr’ or None, default=None3
svc = svm.SVC(C=100,gamma=0.001)
#模型使用(选择)
svc.fit(X_train,y_train)
#通过分类器fit来训练模型 #测试集 ——> 预测标签 <-对比-> 真实标签 y_pred = svc.predict(X_test)
6.绘制预测的数据图像
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
images_and_predictions = list(zip(X_test.reshape(450,8,-1),y_pred))
plt.figure(figsize=(14,9))
for i,(img,pred) in enumerate(images_and_predictions[:24],start=1):
plt.subplot(4,6,i,)
plt.subplots_adjust(hspace = 0.3)
plt.axis('off')
plt.imshow(img,cmap=plt.cm.gray_r,interpolation='nearest')
plt.title(f'预测值:{str(pred)}')
7.模型保存
joblib.dump(svc,'MyDigitsModel.pkl')
读取保存的文件
joblib.load('MyDigitsModel.pkl')
后记:
有需要的朋友可以搜索微信公众号:【知音库】
同时也是为了鼓励自己,坚持写笔记,希望可以共同进步。