MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取，是由 Yann LeCun 整理的手写数字数据集，分为以下四个部分：
dataset | name | details |
---|---|---|
Training set images | train-images-idx3-ubyte.gz | 60,000 个样本的像素值 |
Training set labels | train-labels-idx1-ubyte.gz | 60,000 个标签 |
Test set images | t10k-images-idx3-ubyte.gz | 10,000 个样本的像素值 |
Test set labels | t10k-labels-idx1-ubyte.gz | 10,000 个标签 |
数据读取
import gzip
import struct
def read_data(label_url, image_url):
    """Read one MNIST split (labels + images) from gzip-compressed IDX files.

    Args:
        label_url: Path to a gzip-compressed IDX label file
            (e.g. ``train-labels-idx1-ubyte.gz``).
        image_url: Path to the matching gzip-compressed IDX image file
            (e.g. ``train-images-idx3-ubyte.gz``).

    Returns:
        A tuple ``(label, image)``: ``label`` is a writable 1-D ``int8`` array
        of length N, ``image`` is a writable ``uint8`` array of shape
        ``(N, rows, cols)`` with ``rows``/``cols`` taken from the image header.

    Raises:
        ValueError: If a file has the wrong magic number, its payload size does
            not match its header, or the two files hold different item counts.
    """
    # IDX magic numbers: 0x00000801 (2049) for 1-D unsigned-byte data (labels),
    # 0x00000803 (2051) for 3-D unsigned-byte data (images).
    with gzip.open(label_url, 'rb') as flbl:
        magic, num_labels = struct.unpack(">II", flbl.read(8))
        if magic != 0x00000801:
            raise ValueError(
                "{}: bad label magic number {} (expected 2049)".format(label_url, magic))
        # frombuffer replaces the deprecated binary mode of fromstring; .copy()
        # keeps the array writable, as fromstring's result was.
        # int8 is kept for backward compatibility; MNIST labels (0-9) fit in it.
        label = np.frombuffer(flbl.read(), dtype=np.int8).copy()
    if label.size != num_labels:
        raise ValueError("{}: header declares {} labels but file holds {}".format(
            label_url, num_labels, label.size))

    with gzip.open(image_url, 'rb') as fimg:
        magic, num_images, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != 0x00000803:
            raise ValueError(
                "{}: bad image magic number {} (expected 2051)".format(image_url, magic))
        pixels = np.frombuffer(fimg.read(), dtype=np.uint8)
    if num_images != len(label):
        raise ValueError("{} holds {} images but {} holds {} labels".format(
            image_url, num_images, label_url, len(label)))
    if pixels.size != num_images * rows * cols:
        raise ValueError("{}: expected {} pixel bytes, found {}".format(
            image_url, num_images * rows * cols, pixels.size))
    # Copy so the returned image array is writable (frombuffer views are read-only).
    image = pixels.reshape(num_images, rows, cols).copy()
    return (label, image)
获取训练集（Train）和测试集（Test）
输入参数 ohe 表示是否对标签进行 one-hot 编码；输出是由像素值（图像）与标签值构成的 tuple，顺序为 (图像, 标签)。
def _load_mnist_split(label_path, image_path, ohe, num_classes):
    """Load one MNIST split and return ``(images, labels)`` ready for a model.

    Adds a trailing channel axis to the images, passes them through
    ``preprocessing_img`` (normalisation, per the original comment; defined
    elsewhere), and optionally one-hot encodes the labels.
    """
    (lbl, img) = read_data(label_path, image_path)
    # Add a trailing channel axis: (N, rows, cols) -> (N, rows, cols, 1).
    img = img.reshape((*img.shape, 1))
    img = preprocessing_img(img)  # normalisation
    if ohe:
        # Use a fixed class count instead of len(np.unique(lbl)): a split that
        # lacks some digit would otherwise get a narrower encoding than the
        # other split, or a count too small for its largest label.
        lbl = np_utils.to_categorical(lbl, num_classes=num_classes)
    return img, lbl


def get_train(ohe=True, num_classes=10):
    """Load the MNIST training split.

    Args:
        ohe: If True, one-hot encode the labels.
        num_classes: Width of the one-hot encoding (10 digit classes by default).

    Returns:
        ``(images, labels)`` where images have a trailing channel axis and have
        been passed through ``preprocessing_img``.
    """
    return _load_mnist_split('DataSet/Mnist/train-labels-idx1-ubyte.gz',
                             'DataSet/Mnist/train-images-idx3-ubyte.gz',
                             ohe, num_classes)


def get_test(ohe=True, num_classes=10):
    """Load the MNIST test split.

    Args:
        ohe: If True, one-hot encode the labels.
        num_classes: Width of the one-hot encoding (10 digit classes by default).

    Returns:
        ``(images, labels)`` where images have a trailing channel axis and have
        been passed through ``preprocessing_img``.
    """
    return _load_mnist_split('DataSet/Mnist/t10k-labels-idx1-ubyte.gz',
                             'DataSet/Mnist/t10k-images-idx3-ubyte.gz',
                             ohe, num_classes)