用 Python + NumPy 从零实现一个识别手写数字的简单神经网络
import numpy as np
def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), applied element-wise.

    Evaluated through exp(-|x|), which lies in (0, 1], so np.exp can never
    overflow. The direct formula overflows (RuntimeWarning) for very negative x.

    Args:
        x: scalar or array-like.

    Returns:
        Values in [0, 1] with the same shape as ``x``; a NumPy scalar when ``x``
        is a scalar.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))  # always in (0, 1]: no overflow possible
    # For x >= 0: 1 / (1 + e^-x).  For x < 0: e^x / (1 + e^x) (same value, stable).
    s = np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
    return s[()]  # unwrap a 0-d result into a NumPy scalar
class NeuralNetword:
    """Neural network with one hidden layer: input -> hidden -> output.

    Every layer uses the logistic sigmoid, and weights are updated one example
    at a time by backpropagation. (The class name's spelling is kept so that
    existing callers keep working.)
    """

    def __init__(self, inputnodes, hiddennodes, outputnodes, lr, rng=None):
        """Store the layer sizes and learning rate and draw random initial weights.

        Args:
            inputnodes: number of input-layer nodes.
            hiddennodes: number of hidden-layer nodes.
            outputnodes: number of output-layer nodes.
            lr: learning rate scaling every weight update.
            rng: optional random source with a ``normal(loc, scale, size)``
                method, e.g. ``np.random.default_rng(seed)`` for reproducible
                weights. ``None`` uses the global ``np.random`` state.
        """
        # Three layers: input, hidden, output.
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes
        self.lr = lr
        self._init_weights(np.random if rng is None else rng)
        # Activation function. train() hard-codes the sigmoid derivative
        # o * (1 - o), so this must stay the logistic sigmoid.
        self.ac_fun = sigmoid

    def _init_weights(self, rng):
        """Draw wih (hnodes x inodes) and who (onodes x hnodes) from a zero-mean normal."""
        # Standard deviation is 1/sqrt(fan-in): the number of links feeding each
        # receiving node (not the receiving layer's own size).
        self.wih = rng.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
        self.who = rng.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))

    def _forward(self, inputs):
        """Forward pass for a column vector; return (hidden_outputs, final_outputs)."""
        hidden_outputs = self.ac_fun(np.dot(self.wih, inputs))
        final_outputs = self.ac_fun(np.dot(self.who, hidden_outputs))
        return hidden_outputs, final_outputs

    def train(self, inputs_list, targets_list):
        """Apply one backpropagation update for a single training example.

        Args:
            inputs_list: 1-D sequence of ``inodes`` input values.
            targets_list: 1-D sequence of ``onodes`` target output values.
        """
        # ndmin=2 then .T turns a 1-D sequence into a column vector.
        inputs = np.array(inputs_list, ndmin=2).T
        targets = np.array(targets_list, ndmin=2).T
        hidden_outputs, final_outputs = self._forward(inputs)
        output_errors = targets - final_outputs
        # Propagate the error back through who *before* who is updated below.
        hidden_errors = np.dot(self.who.T, output_errors)
        # o * (1 - o) is the sigmoid derivative expressed via the sigmoid's output.
        self.who += self.lr * np.dot(
            output_errors * final_outputs * (1.0 - final_outputs), hidden_outputs.T
        )
        self.wih += self.lr * np.dot(
            hidden_errors * hidden_outputs * (1.0 - hidden_outputs), inputs.T
        )

    def query(self, inputs_list):
        """Return the output-layer activations, shape (onodes, 1), for one 1-D input."""
        inputs = np.array(inputs_list, ndmin=2).T
        return self._forward(inputs)[1]
# Train the network with one pass over the MNIST training CSV. Each row is
# "label,pixel,pixel,...": the label is in column 0, and the pixel count must
# match the 784 input nodes (28 x 28 images).
cnn = NeuralNetword(784, 100, 10, 0.1)  # 784 inputs, 100 hidden, 10 digits, lr 0.1
with open('mnist_train.csv', 'r') as f:
    # Stream line by line instead of holding the whole file in memory (readlines).
    for line in f:
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. an empty last line
        values = line.split(",")
        # Scale pixels 0..255 into 0.01..1.0. A zero input would zero the weight
        # update of its links, because train() multiplies by the inputs.
        inputs = np.asarray(values[1:], dtype=float) / 255.0 * 0.99 + 0.01
        # Targets: 0.99 for the labelled digit and 0.01 elsewhere, since a
        # sigmoid output never actually reaches 0 or 1.
        targets = np.full(10, 0.01)
        targets[int(values[0])] = 0.99
        cnn.train(inputs, targets)
print('done!')
输出：done!
# Evaluate on the MNIST test CSV. Accuracy is the fraction of rows whose
# most-activated output node equals the label in column 0.
# matplotlib is not used in this loop. The imports are kept, and hoisted out of
# the loop, because the custom-image demo that follows calls plt.imshow.
import matplotlib
import matplotlib.pyplot as plt

alls = 0    # rows evaluated
succe = 0   # rows predicted correctly
with open('mnist_test.csv', 'r') as f:
    # Stream line by line instead of loading the whole file with readlines().
    for line in f:
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. an empty last line
        values = line.split(",")
        # Same 0..255 -> 0.01..1.0 scaling as the training inputs.
        inputs = np.asarray(values[1:], dtype=float) / 255.0 * 0.99 + 0.01
        out = cnn.query(inputs)
        # np.argmax returns the first index of the maximum, like list.index(max(...)).
        if int(np.argmax(out)) == int(values[0]):
            succe += 1
        alls += 1
if alls:
    print(succe / float(alls))
else:
    print('no samples found in mnist_test.csv')
输出：0.9517
即测试集上的识别准确率约为 95%。
# Recognise a digit from your own handwritten image ('2.png').
# The image is inverted before use. Presumably it is dark ink on a light
# background, the reverse of the training data's encoding; adjust if not.
import cv2
import matplotlib.pyplot as plt  # local import so this demo stands on its own

img = cv2.imread('2.png', cv2.IMREAD_GRAYSCALE)
if img is None:
    # cv2.imread returns None (instead of raising) for a missing or unreadable file.
    raise FileNotFoundError("could not read image '2.png'")
if img.shape != (28, 28):
    # The network takes 784 = 28 x 28 inputs, so other sizes are resized to fit.
    img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
data = img.flatten()
# Invert the raw pixels first, then apply the same 0..255 -> 0.01..1.0 scaling
# used for training. Inverting after scaling would give 0.0..0.99 instead.
inverted = 255.0 - np.asarray(data, dtype=float)
inputs = inverted / 255.0 * 0.99 + 0.01
out = cnn.query(inputs)
image_array = inverted.reshape((28, 28))
plt.imshow(image_array, cmap='Greys', interpolation='None')
# Predicted digit = index of the most activated output node.
print(int(np.argmax(out)))
输出：2（识别结果为数字 2）