正在学习斯坦福的cs231n课程,该课程使用的是CIFAR-10数据集
该数据集可在管网下载
http://www.cs.toronto.edu/~kriz/cifar.html
下载并解压,得到
如何导入数据
CIFAR-10数据集由pickle产生,因此也由pickle导入
import pickle
def load_file(filename):
with open(filename, 'rb') as fo:
data = pickle.load(fo, encoding='latin1')
return data
filename = 'D:/Download/cifar-10-batches-py/data_batch_1'
data = load_file(filename)
print(data.keys())//得到当前文件的一些基本信息
当前文件的一些基本信息
dict_keys(['batch_label', 'labels', 'data', 'filenames'])
NN分类的思想
NN分类并不需要训练,只需要将要判断的图和已有数据进行比较即可
比较时计算目标图与每一个数据图的范数一,范数一最小的数据图所属类别即为目标图类别
关于范数一与范数二代码如下
import numpy as np
import pickle
filename = 'xxx'
filename_test = 'xxx'
class NearestNeighbor:
"""docstring for NearestNeighbor"""
def __init__(self):
pass
# 导入数据
def load_file(self, filename):
with open(filename, 'rb') as fo:
data = pickle.load(fo, encoding='latin1')
return data
# 训练模型,NN只是简单的导入即可,X是数据,n*3072,Y是数据标签,n*1
def train(self, X, y):
self.Xtr = X
self.ytr = y
# 使用模型进行预测,X是test集的数据
def predict(self, X):
num_test = X.shape[0]# test数据个数
Ypred = np.zeros(num_test)# 初始化预测结果
for i in range(num_test):
distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)# 计算范数一
min_index = np.argmin(distances)# 寻范数一最小的数据
Ypred[i] = self.ytr[min_index]# 得到预测结果
return Ypred
net = NearestNeighbor()
data = net.load_file(filename)
test_batch = net.load_file(filename_test)
net.train(data['data'], data['labels'])
result = net.predict(test_batch['data'])
print(result)