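# Tip: use the Paint tool that ships with Windows to draw a black digit on a
# white 8*8-pixel canvas, then recognize it with sklearn using this script.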
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import os
import cv2
def main():
    """Train an SVM on the sklearn digits data set and classify a local image.

    Steps: load the digits data, print basic information, preview the first
    ten samples, flatten every 8x8 image into a 64-feature row, fit an SVC on
    a stratified 70/30 split, then hand the fitted model to
    ``predict_local_digit`` together with a hard-coded image path.
    """
    # 1. Load the handwritten-digits data set.
    digits = datasets.load_digits()

    # 2. Print basic information about the data.
    print(f"数据形状:{digits.images.shape}")
    print(f"目标标签数量:{len(np.unique(digits.target))}")

    # 3. Preview the first ten samples on a 2x5 grid.
    _, grid = plt.subplots(2, 5, figsize=(10, 5))
    for panel, image, label in zip(grid.flat, digits.images, digits.target):
        panel.imshow(image, cmap="gray")
        panel.set_title(f"digit: {label}")
        panel.axis("off")
    plt.tight_layout()
    plt.show()

    # Dump the raw pixel values of the sample stored at index 7.
    print("样本图像数字7的像素值")
    print(digits.images[7])

    # 4. Pre-processing: flatten each 8x8 image into one 64-value row.
    X = digits.images.reshape((len(digits.images), -1))
    y = digits.target

    # Stratified split: 70% for training, 30% for testing.
    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.3,
        random_state=42,
        stratify=y,
    )

    # 5. Build and fit the SVM classifier.
    clf = svm.SVC(gamma="scale")
    print("正在训练模型...")
    clf.fit(X_train, y_train)

    # 6. Predict the held-out test set; the predictions are only consumed by
    #    the disabled evaluation code below.
    y_pred = clf.predict(X_test)
    print(f"X_test shape: {X_test.shape}")

    # 7. Model evaluation (currently disabled).
    # print("\n准确率:{:.2f}%".format(accuracy_score(y_test, y_pred) * 100))
    # print("\n分类报告:")
    # print(classification_report(y_test, y_pred))
    # print("\n混淆矩阵:")
    # print(confusion_matrix(y_test, y_pred))

    # 8. Visual check of a few test predictions (currently disabled).
    # fig, axes = plt.subplots(2, 5, figsize=(10, 5))
    # for i in range(10):
    #     ax = axes[i // 5][i % 5]
    #     ax.imshow(X_test[i].reshape(8, 8), cmap="gray")
    #     ax.set_title(f"test: {y_test[i]}, pred: {y_pred[i]}")
    #     ax.axis("off")
    # plt.tight_layout()
    # plt.show()

    # Classify a locally stored image, e.g. a hand-drawn digit 3.
    local_image = "D://test//num.png"
    predict_local_digit(local_image, clf)
# Load a local image and predict the digit it shows.
def predict_local_digit(image_path, model, *, invert=True):
    """Predict the digit drawn in a local image file with a trained model.

    The image is read as grayscale, shrunk to 8x8 pixels, rescaled from the
    0-255 intensity range to the 0-16 range used by the sklearn digits data
    set, optionally inverted, flattened into a single 64-value row and passed
    to ``model.predict``. Intermediate pixel arrays are printed and the
    shrunk image is displayed for inspection.

    Args:
        image_path: Path of the image file to classify.
        model: Fitted estimator whose ``predict`` accepts a (1, 64) array.
        invert: Keyword-only. ``True`` (default) flips the intensities, which
            is needed for dark strokes on a light background because the
            sklearn digits are light strokes on a dark background. Pass
            ``False`` when the image already uses a dark background.

    Returns:
        The predicted label, or ``None`` when the file does not exist or
        cannot be decoded as an image.
    """
    if not os.path.exists(image_path):
        print(f"❌ 错误:文件 {image_path} 不存在!")
        return None

    # Read the picture as a single-channel grayscale image.
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print("❌ 无法读取图片,请检查格式或路径。")
        return None

    # Shrink to 8x8 pixels so the layout matches the sklearn digits samples.
    img_small = cv2.resize(img, (8, 8), interpolation=cv2.INTER_AREA)
    print("原始图像的像素值")
    print(img_small)

    # Show the shrunk image for a quick visual check.
    plt.figure(figsize=(10, 5))
    plt.imshow(img_small, cmap='gray')
    plt.title("init img")
    plt.axis('off')
    plt.show()

    # Rescale 0-255 intensities to 0-16. The training data spans 0..16, so a
    # plain ``// 16`` (which tops out at 15) would never reach full intensity.
    img_scaled = np.rint(img_small.astype(np.float64) * (16 / 255)).astype(np.uint8)
    print("归一化图像的像素值")
    print(img_scaled)

    if invert:
        # Values are within 0..16 here, so the subtraction cannot wrap around.
        img_scaled = 16 - img_scaled
        print(img_scaled)

    # Flatten into a single row of 64 features and predict.
    features = img_scaled.reshape(1, -1)
    pred = model.predict(features)[0]
    print(f"✅ 预测数字是: {pred}")
    return pred
# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()