参考CS231n,将KNN 跑起来了,成功将系统搞死,,内存和计算能力开销太大。
以下代码 切记不用轻易跑。。
数据集:
http://www.cs.toronto.edu/~kriz/cifar.html
code:
import os
import sys
import numpy as np
import pickle
def load_CIFAR_batch(filename):
"""
cifar-10数据集是分batch存储的,这是载入单个batch
@参数 filename: cifar文件名
@r返回值: X, Y: cifar batch中的 data 和 labels
"""
with open(filename,"rb") as f :
datadict = pickle.load(f,encoding='iso-8859-1')
print(filename)
X=datadict['data']
Y=datadict['labels']
X=X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
Y=np.array(Y)
return X, Y
def load_CIFAR10(ROOT):
"""
读取载入整个 CIFAR-10 数据集
@参数 ROOT: 根目录名
@return: X_train, Y_train: 训练集 data 和 labels
X_test, Y_test: 测试集 data 和 labels
"""
xs=[]
ys=[]
for b in range(1,6):
f=os.path.join(ROOT, "data_batch_%d" % (b, ))
X, Y=load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
X_train=np.concatenate(xs)
Y_train=np.concatenate(ys)
del X, Y
X_test, Y_test=load_CIFAR_batch(os.path.join(ROOT, "test_batch"))
return X_train, Y_train, X_test, Y_test
# 载入训练和测试数据集
X_train, Y_train, X_test, Y_test = load_CIFAR10('data/cifar/')
# 把32*32*3的多维数组展平
Xtr_rows = X_train.reshape(X_train.shape[0], 32 * 32 * 3) # Xtr_rows : 50000 x 3072
Xte_rows = X_test.reshape(X_test.shape[0], 32 * 32 * 3) # Xte_rows : 10000 x 3072
class NearestNeighbor:
def __init__(self):
pass
def train(self, X, y):
"""
这个地方的训练其实就是把所有的已有图片读取进来 -_-||
"""
# the nearest neighbor classifier simply remembers all the training data
self.Xtr = X
self.ytr = y
def predict(self, X):
"""
所谓的预测过程其实就是扫描所有训练集中的图片,计算距离,取最小的距离对应图片的类目
"""
num_test = X.shape[0]
# 要保证维度一致哦
Ypred = np.zeros(num_test, dtype = self.ytr.dtype)
# 把训练集扫一遍 -_-||
for i in range(num_test):
# 计算l1距离,并找到最近的图片
distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)
min_index = np.argmin(distances) # 取最近图片的下标
Ypred[i] = self.ytr[min_index] # 记录下label
return Ypred
nn = NearestNeighbor() # 初始化一个最近邻对象
nn.train(Xtr_rows, Y_train) # 训练...其实就是读取训练集
Yte_predict = nn.predict(Xte_rows) # 预测
# 比对标准答案,计算准确率
print ('accuracy: %f' % ( np.mean(Yte_predict == Y_test)))