下载并制作CIFAR10数据集

import os
import glob
import sys
from urllib import request
import tarfile
import numpy as np
import cv2
from tqdm import tqdm

# CIFAR10数据集的类别
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
# CIFAR10 数据下载地址
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'


# 下载数据集并解压
def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """
    Downloads the `tarball_url` and uncompresses it locally.
    :param tarball_url: The URL of a tarball file.
    :param dataset_dir: The directory where the temporary files are stored.
    :return:
    """
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(block_num, block_size, total_size):
        sys.stdout.write('\r>>>Downloading %s %.1f%%' % (
            filename, float(block_num * block_size / total_size * 100.0)
        ))
        sys.stdout.flush()

    filepath, _ = request.urlretrieve(tarball_url, filepath, _progress)
    print()
    stat_info = os.stat(filepath)
    print('Successfully downloaded', filename, stat_info.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(dataset_dir)  # 解压


# 解析CIFAR10数据集格式
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        # dict_: dict_.keys() -> dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
        dict_ = pickle.load(fo, encoding='bytes')
    return dict_


# 将numpy data 转变成图片格式
def numpy_data_to_image(data_batch_files):
    # ================================================
    # 数据集处理
    # ================================================
    # 以下的注释以建立训练集为例,测试集同理
    # 将五个data_batch的数据一起放到一个 data label列表中
    data = []  # len(data) == 50000
    labels = []
    for file in data_batch_files:
        file_dict = unpickle(file)
        # 这儿的list(*)相当于将形状为(10000, 3072)的二维numpy数组,变为长度为10000的列表
        # 列表的每一项为3072长度的numpy数组
        data += list(file_dict[b'data'])
        labels += list(file_dict[b'labels'])
    # print(labels)   # [6, 9, 9, 4, 1, 1, 2, 7, 8, 3, ..., ] 50000长度
    # print(data)   # [array([  0,   1,   1, ..., 198, 195, 198], dtype=uint8), array([...]), ..., ] 50000长度

    # 将data列表中的每个array数组reshape成image格式
    images = np.reshape(data, [-1, 3, 32, 32])
    return images, labels


# 保存图片
def save_image(images, save_folder, labels):
    for i in tqdm(list(range(images.shape[0]))):
        img_data = images[i, ...]  # img_data.shape -> (3, 32, 32)
        # 变成cv2存储图片的格式 ,通道最后
        img_data = np.transpose(img_data, [1, 2, 0])
        # 继续变成cv2存储图片的格式, 通道顺序为 BRG
        img_data = cv2.cvtColor(img_data, cv2.COLOR_RGB2BGR)  # 说明cifar10的通道顺序是RGB

        # 为将要保存的图片命名,每一类图片放在一个文件夹中
        img_folder_path = "{}/{}".format(save_folder, CLASSES[labels[i]])

        if not os.path.exists(img_folder_path):
            os.makedirs(img_folder_path)

        cv2.imwrite("{}/{}.jpg".format(img_folder_path, str(i)), img_data)


def make_dataset(batches_folder, save_folder, is_train=True):

    if is_train:
        files = glob.glob(batches_folder + '/data_batch_*')
        save_folder = os.path.join(save_folder, 'image/train')
    else:
        files = glob.glob(batches_folder + '/test_batch*')
        save_folder = os.path.join(save_folder, 'image/test')

    # 将numpy 数据转成image格式
    images, labels = numpy_data_to_image(files)

    # 保存图片
    flag = 'train' if is_train else 'test'
    print("\t>>>Making %s dataset, starting..." % flag)
    save_image(images, save_folder, labels)
    print("Successfully make %s dataset." % flag)


if __name__ == '__main__':

    # 自己想要设定的数据路径,下载后的cifar10数据集会存储在当前路径下
    DATA_PATH = '/home/donglin/IMOOC/Python3+TensorFlow打造人脸识别智能小程序/cifar10/MY_CIFAR10/data'

    folders = os.path.join(DATA_PATH, 'cifar-10-batches-py')
    save_path = DATA_PATH

    # 如果 'cifar-10-batches-py' 文件夹不存在,则预先下载
    if not os.path.exists(folders):
        download_and_uncompress_tarball(DATA_URL, DATA_PATH)

    # 制作数据集
    make_dataset(folders, save_path, True)
    make_dataset(folders, save_path, False)
    print("Success!")

文件夹目录


image.png
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容