import os
import glob
import sys
from urllib import request
import tarfile
import numpy as np
import cv2
from tqdm import tqdm
# The ten CIFAR-10 class names, indexed by the integer label value.
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
# Download URL of the python (pickle) version of the CIFAR-10 dataset.
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
# Download the dataset archive and unpack it.
def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """
    Downloads the `tarball_url` and uncompresses it locally.

    :param tarball_url: The URL of a gzipped tarball file.
    :param dataset_dir: The directory where the archive and its contents
        are stored.
    :return: None
    """
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(block_num, block_size, total_size):
        # urlretrieve reports total_size as -1 (or 0) when the server sends
        # no Content-Length; guard the division. The final block usually
        # overshoots the total, so clamp the percentage at 100.
        if total_size > 0:
            percent = min(block_num * block_size / total_size * 100.0, 100.0)
        else:
            percent = 0.0
        sys.stdout.write('\r>>>Downloading %s %.1f%%' % (filename, percent))
        sys.stdout.flush()

    filepath, _ = request.urlretrieve(tarball_url, filepath, _progress)
    print()
    stat_info = os.stat(filepath)
    print('Successfully downloaded', filename, stat_info.st_size, 'bytes.')
    # Context manager ensures the archive handle is closed even on error
    # (the original opened the tarfile without ever closing it).
    with tarfile.open(filepath, 'r:gz') as tar:
        tar.extractall(dataset_dir)
# Parse one CIFAR-10 batch file (Python pickle format).
def unpickle(file):
    """Load a pickled CIFAR-10 batch and return its dict.

    The returned dict uses bytes keys:
    b'batch_label', b'labels', b'data', b'filenames'.

    NOTE(review): pickle.load is unsafe on untrusted input; the CIFAR-10
    archive from the official mirror is assumed trusted here.
    """
    import pickle
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')
# Turn the raw batch files into an image-shaped numpy array plus labels.
def numpy_data_to_image(data_batch_files):
    """Read every batch file and return (images, labels).

    :param data_batch_files: paths of the CIFAR-10 batch files to decode.
    :return: images — uint8 array of shape (N, 3, 32, 32), channels first;
             labels — list of int class indices aligned with images.
    """
    flat_rows = []   # one 3072-long uint8 row per sample
    labels = []
    for batch_file in data_batch_files:
        batch = unpickle(batch_file)
        # b'data' is a (10000, 3072) array; extending iterates its rows.
        flat_rows.extend(batch[b'data'])
        labels.extend(batch[b'labels'])
    # (N, 3072) -> (N, 3, 32, 32)
    images = np.reshape(flat_rows, [-1, 3, 32, 32])
    return images, labels
# Save each image to <save_folder>/<class name>/<index>.jpg.
def save_image(images, save_folder, labels):
    """
    Write every image as a JPEG, one sub-folder per class.

    :param images: uint8 array of shape (N, 3, 32, 32), RGB, channels first.
    :param save_folder: root output directory.
    :param labels: per-image class index (indexes into CLASSES).
    :return: None
    """
    # tqdm accepts range directly (it has a len); no need to build a list.
    for i in tqdm(range(images.shape[0])):
        img_data = images[i, ...]  # (3, 32, 32)
        # OpenCV stores images channels-last: (32, 32, 3).
        img_data = np.transpose(img_data, [1, 2, 0])
        # CIFAR-10 pixel data is RGB; OpenCV expects BGR.
        img_data = cv2.cvtColor(img_data, cv2.COLOR_RGB2BGR)
        img_folder_path = os.path.join(save_folder, CLASSES[labels[i]])
        # exist_ok avoids the check-then-create race of the original
        # os.path.exists / os.makedirs pair.
        os.makedirs(img_folder_path, exist_ok=True)
        cv2.imwrite(os.path.join(img_folder_path, "{}.jpg".format(i)), img_data)
# Build either the train or the test image dataset from the batch files.
def make_dataset(batches_folder, save_folder, is_train=True):
    """Locate the batch files, decode them, and write the images to disk.

    :param batches_folder: directory with the extracted CIFAR-10 batches.
    :param save_folder: root directory receiving image/train or image/test.
    :param is_train: True -> use data_batch_* files; False -> test_batch*.
    :return: None
    """
    if is_train:
        pattern, subdir, flag = '/data_batch_*', 'image/train', 'train'
    else:
        pattern, subdir, flag = '/test_batch*', 'image/test', 'test'
    files = glob.glob(batches_folder + pattern)
    save_folder = os.path.join(save_folder, subdir)
    # Decode the raw numpy batches into image arrays, then write them out.
    images, labels = numpy_data_to_image(files)
    print("\t>>>Making %s dataset, starting..." % flag)
    save_image(images, save_folder, labels)
    print("Successfully make %s dataset." % flag)
if __name__ == '__main__':
    # Root data directory; the downloaded CIFAR-10 archive lands here too.
    DATA_PATH = '/home/donglin/IMOOC/Python3+TensorFlow打造人脸识别智能小程序/cifar10/MY_CIFAR10/data'
    folders = os.path.join(DATA_PATH, 'cifar-10-batches-py')
    save_path = DATA_PATH
    # Download and extract only when 'cifar-10-batches-py' is missing.
    if not os.path.exists(folders):
        download_and_uncompress_tarball(DATA_URL, DATA_PATH)
    # Build both splits: train first, then test.
    for is_train in (True, False):
        make_dataset(folders, save_path, is_train)
    print("Success!")
# Folder layout (illustration from the original write-up):
# 文件夹目录 / image.png