【TensorFlow】关于cifar10数据集的使用(2020新版)

关于在《TensorFlow实战》这本书第5章中出现的cifar10,好多人都pip install cifar10发现失败:
no module named cifar10no modulw named cifar10_input,这是因为你需要下载一个tensorflow 的models,具体链接放在这里https://github.com/tensorflow/models/tree/r1.13.0,新的tensorflow由于变成2.x版了,所以没有models这个包,我这里用的是19年的branch。
下载完成后,将它解压缩并命名为models然后放到你安装tensorflow的那个目录下


紧接着要修改里边的文件
一共要修改两个地方:


删去那两行,用这两行代替
image.png


OK了你可以用了!



华丽的分割线(以上是2020年11月16日更新的,以下是2019年的)


之前看《TensorFlow实战》的时候就卡在了第五章“TensorFlow实现卷积神经网络”,原因是这里的cifar10数据集导入不进去。

cifar10.maybe_download_and_extract()

这里,如果你尝试的话会出现

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-12-02a754d7036a> in <module>()
----> 1 cifar10.maybe_download_and_extract()

AttributeError: module 'cifar10' has no attribute 'maybe_download_and_extract'

找不到 maybe_download_and_extract() 方法。。。什么鬼!

这可咋整,然后我还从官网上下载了cifar-10-batches-py,170多M,然后data_dir。。还是不行,后来彻底放弃了,这憨批TensorFlow。
时隔俩月,我今天又头铁,查了无数资料,没有提到 module 'cifar10' has no attribute 'maybe_download_and_extract'这种错误的???怎么说??网上冲浪的各位难道用的都是2016版TensorFlow-model???
那肯定是model更新了(https://github.com/tensorflow/models.git)models模块在这里有需要可以去下载。
一筹莫展之际,书上的下一行代码引起了我注意:

images_train, labels_train = cifar10_input.distorted_inputs(
                                                data_dir=data_dir, batch_size=batch_size)

OK,数据可以从这个线索去找,于是我打开models下的cifar10_input.py文件查找到distorted_inputs函数如下:

def distorted_inputs(batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.

  Args:
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  return _get_images_labels(batch_size, tfds.Split.TRAIN, distords=True)

OK,那我再去找_get_images_labels,也是cifar10_input.py中(奇怪为什么同一个功能要用两个函数):

def _get_images_labels(batch_size, split, distords=False):
  """Returns Dataset for given split."""
  dataset = tfds.load(name='cifar10', split=split)
  scope = 'data_augmentation' if distords else 'input'
  with tf.name_scope(scope):
    dataset = dataset.map(DataPreprocessor(distords), num_parallel_calls=10)
  # Dataset is small enough to be fully loaded on memory:
  dataset = dataset.prefetch(-1)
  dataset = dataset.repeat().batch(batch_size)
  iterator = dataset.make_one_shot_iterator()
  images_labels = iterator.get_next()
  images, labels = images_labels['input'], images_labels['target']
  tf.summary.image('images', images)
  return images, labels

可以看到datasets 是由tfds的load方法加载的
tfds,也就是import tensorflow_datasets as tfds
找到cifar10_download_and_extract.py这个文件,发现

DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'

这是cifar的下载路径,然后我打开,发现500服务器拒绝访问,难怪下载不了

修改方案:DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'改为
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'

去掉s再访问发现是可行的(当场摔电脑)
保存,运行

images_train, labels_train = cifar10_input.distorted_inputs(
                                                data_dir=data_dir, batch_size=batch_size)
下载.png

等下载完成后程序自动解压。此时在C:\Users\***\tensorflow_datasets文件夹下可以找到“cifar10”和“download”这两个文件夹。
OK完成,可以进行接下来的学习了。


2019年6月20日编辑
今天我尝试了一下,发现改成'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'也能正常下载了,不知道为什么,昨天访问的时候是500服务器错误。
需要注意的一点是,无论有没有http(s),我们都需要搭VPN才能下载。毕竟你懂的。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容