sklearn识别手写数字

可以使用windows自带的画图工具,画一个白底黑字的8*8像素的数字,然后使用sklearn识别

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import os
import cv2

def main():
    # 1.加载手写数据集
    digits = datasets.load_digits()

    # 2.查看数据的基本信息
    print(f"数据形状:{digits.images.shape}")
    print(f"目标标签数量:{len(np.unique(digits.target))}")

    # 3.可视化前几个样本
    fig, axes = plt.subplots(2, 5, figsize=(10, 5))
    for i, ax in enumerate(axes.flat):
        ax.imshow(digits.images[i], cmap="gray")
        ax.set_title(f"digit: {digits.target[i]}")
        ax.axis("off")
    plt.tight_layout()
    plt.show()
    print("样本图像数字7的像素值")
    print(digits.images[7])

    # 4.数据预处理:将8*8图像展开为64维特征向量
    n_samples = len(digits.images)
    # (1797, 64)
    X = digits.images.reshape((n_samples, -1))
    y = digits.target
    # 划分训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)

    # 5.创建并训练SVM模型
    clf = svm.SVC(gamma="scale")
    print("正在训练模型...")
    clf.fit(X_train, y_train)

    # 6.预测数据集
    y_pred = clf.predict(X_test)
    print(f"X_test shape: {X_test.shape}")

    # 7.评估模型性能
    # print("\n准确率:{:.2f}%".format(accuracy_score(y_test, y_pred) * 100))
    # print("\n分类报告:")
    # print(classification_report(y_test, y_pred))
    # print("\n混淆矩阵:")
    # print(confusion_matrix(y_test, y_pred))

    # 8.可视化一些测试结果
    # fig, axes = plt.subplots(2, 5, figsize=(10, 5))
    # for i in range(10):
    #     ax = axes[i // 5][i % 5]
    #     ax.imshow(X_test[i].reshape(8, 8), cmap="gray")
    #     ax.set_title(f"test: {y_test[i]}, pred: {y_pred[i]}")
    #     ax.axis("off")
    # plt.tight_layout()
    # plt.show()

    image_path = "D://test//num.png"  # 例如:手写数字 3 的图片
    predict_local_digit(image_path, clf)


 # 加载本地图片并预测
def predict_local_digit(image_path, model):
    """
    加载本地图片,预处理,并用训练好的模型预测数字
    """
    if not os.path.exists(image_path):
        print(f"❌ 错误:文件 {image_path} 不存在!")
        return

    # 读取图片(灰度)
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print("❌ 无法读取图片,请检查格式或路径。")
        return
    # 调整大小为 8x8 像素(与 sklearn digits 一致)
    img_dealed = cv2.resize(img, (8, 8), interpolation=cv2.INTER_AREA)
    print("原始图像的像素值")
    print(img_dealed)

    # 查看原始图像(可选)
    plt.figure(figsize=(10, 5))
    plt.imshow(img_dealed, cmap='gray')
    plt.title("init img")
    plt.axis('off')
    plt.show()

    # 0-255的像素颜色值近似转换到 0-15 范围
    img_dealed = img_dealed // 16
    print("归一化图像的像素值")
    print(img_dealed)
    # sklearn digits是黑底白字。如果你的图片是白底黑字则需要像素值反转
    img_dealed = 15 - img_dealed
    print(img_dealed)
    # 剔除非法数值
    img_dealed = np.clip(img_dealed, 0, 15).astype(np.uint8)
    print(img_dealed)

    # 2. 展平为 64 维向量
    img_flat = img_dealed.reshape(1, -1)

    # 3. 预测
    pred = model.predict(img_flat)[0]
    print(f"✅ 预测数字是: {pred}")
    return pred

if __name__ == "__main__":
    main()
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容