Chapter 50: The CIFAR100 Dataset and a ResNet Implementation

The previous chapter covered the ResNet model and its components, as well as the model components built into JAX under jax.example_libraries.stax. With that groundwork done, it is time to write code: this chapter uses ResNet to classify the CIFAR100 dataset.

What Is the CIFAR100 Dataset

CIFAR10 and CIFAR100 are both datasets of small labeled images. Compared with CIFAR10, CIFAR100 has 100 classes with 600 images each: 500 training images and 100 test images per class. The 100 classes are further grouped into 20 superclasses, listed below:

aquatic mammals: beaver, dolphin, otter, seal, whale
fish: aquarium fish, flatfish, ray, shark, trout
flowers: orchids, poppies, roses, sunflowers, tulips
food containers: bottles, bowls, cans, cups, plates
fruit and vegetables: apples, mushrooms, oranges, pears, sweet peppers
household electrical devices: clock, computer keyboard, lamp, telephone, television
household furniture: bed, chair, couch, table, wardrobe
insects: bee, beetle, butterfly, caterpillar, cockroach
large carnivores: bear, leopard, lion, tiger, wolf
large man-made outdoor things: bridge, castle, house, road, skyscraper
large natural outdoor scenes: cloud, forest, mountain, plain, sea
large omnivores and herbivores: camel, cattle, chimpanzee, elephant, kangaroo
medium-sized mammals: fox, porcupine, possum, raccoon, skunk
non-insect invertebrates: crab, lobster, snail, spider, worm
people: baby, boy, girl, man, woman
reptiles: crocodile, dinosaur, lizard, snake, turtle
small mammals: hamster, mouse, rabbit, shrew, squirrel
trees: maple, oak, palm, pine, willow
vehicles 1: bicycle, bus, motorcycle, pickup truck, train
vehicles 2: lawn-mower, rocket, streetcar, tank, tractor

Each image carries a "fine" label (its class) and a "coarse" label (its superclass), and measures 32 x 32 pixels.

Figure 1: CIFAR classes

The dataset can be obtained in two ways:

  • Download it from the official site. Three versions are available:

    Version                                               Size     md5sum
    CIFAR-100 python version                              161 MB   eb9058c3a382ffc7106e4002c42a8d85
    CIFAR-100 Matlab version                              175 MB   6a4bfa1dcd5c9453dda6bb54194911f4
    CIFAR-100 binary version (suitable for C programs)    161 MB   03b5dce01913d631647c71ecec9e9cb8

    Here we choose the python version.

  • Download it with tensorflow_datasets.

Both approaches are described below.

Generating the Dataset from the Downloaded CIFAR100 Files

After the CIFAR-100 python version is downloaded and unpacked, it has the following file structure:

train
test
meta
file.txt~

Here, meta holds the dataset metadata, train is the training set, and test is the test set. The dataset can be read with the following code:


import pickle

def setup():
    
    def load(fileName: str):
        
        with open(file = fileName, mode = "rb") as handler:
            
            data = pickle.load(file = handler, encoding = "latin1")
            
        return data
    
    trains = load("../../Shares/cifar-100-python/train")
    tests = load("../../Shares/cifar-100-python/test")
    metas = load("../../Shares/cifar-100-python/meta")
    
    return trains, tests, metas
    
def train():
    
    trains, tests, metas = setup()
    
    for key in trains.keys():
        
        print(f"key = {key}, len(trains[key]) = {len(trains[key])}")
    
    print("--------------------------------------------------")
    
    for key in tests.keys():
        
        print(f"key = {key}, len(tests[key]) = {len(tests[key])}")
    
    print("--------------------------------------------------")
    
    for key in metas.keys():
        
        print(f"key = {key}, len(metas[key]) = {len(metas[key])}")

def main():
    
    train()

if __name__ == "__main__":
    
    main()

The printed output is as follows:


key = filenames, len(trains[key]) = 50000
key = batch_label, len(trains[key]) = 21
key = fine_labels, len(trains[key]) = 50000
key = coarse_labels, len(trains[key]) = 50000
key = data, len(trains[key]) = 50000
--------------------------------------------------
key = filenames, len(tests[key]) = 10000
key = batch_label, len(tests[key]) = 20
key = fine_labels, len(tests[key]) = 10000
key = coarse_labels, len(tests[key]) = 10000
key = data, len(tests[key]) = 10000
--------------------------------------------------
key = fine_label_names, len(metas[key]) = 100
key = coarse_label_names, len(metas[key]) = 20

The keys are explained below:

  • filenames: a list of length 50000; each entry is the file name of one image.
  • batch_label: information about the batch.
  • fine_labels: the fine class of each image.
  • coarse_labels: the superclass of each image.
  • data: a two-dimensional array of 50000 x 3072; each row holds the pixel values of one image (see the sketch below).
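
The following is a minimal sketch (building on the setup() function above; the channel-first row layout is from the official CIFAR description) that recovers one image and its class names from these dictionaries:

trains, tests, metas = setup()

# Each row of "data" holds 3072 uint8 values: the first 1024 are the red
# channel, the next 1024 green, the last 1024 blue. Reshape to (3, 32, 32)
# and move the channel axis last to get the usual HWC image layout.
row = trains["data"][0]
image = row.reshape(3, 32, 32).transpose(1, 2, 0)   # shape (32, 32, 3)

fine_name = metas["fine_label_names"][trains["fine_labels"][0]]
coarse_name = metas["coarse_label_names"][trains["coarse_labels"][0]]
print(image.shape, fine_name, coarse_name)
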
Using tensorflow_datasets

import tensorflow as tf
import tensorflow_datasets as tfds
import jax

def setup():
    
    (trains, tests), meta = tfds.load("cifar100", data_dir = "/tmp/",
                                      split = [tfds.Split.TRAIN, tfds.Split.TEST],
                                      with_info = True, batch_size = -1)
    
    # tfds.show_examples(trains, meta)
        
    trains = tfds.as_numpy(trains)
    tests = tfds.as_numpy(tests)
    
    train_images, train_labels = trains["image"], trains["label"]
    test_images, test_labels = tests["image"], tests["label"]
    
    return (train_images, train_labels), (test_images, test_labels)
    
def train():
    
    (train_images, train_labels), (test_images, test_labels) = setup()
    
    print((train_images.shape, train_labels.shape), (test_images.shape, test_labels.shape))
    
def main():
    
    train()
    
if __name__ == "__main__":
    
    main()

The printed output is as follows:


((50000, 32, 32, 3), (50000,)) ((10000, 32, 32, 3), (10000,))
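
Before training, these arrays are typically normalized and the labels one-hot encoded. A small sketch of that step follows (the preprocessing choices are assumptions for illustration, not from the original text):

import jax
import jax.numpy as jnp

(train_images, train_labels), (test_images, test_labels) = setup()

# Scale the uint8 pixels to [0, 1] and one-hot encode the 100 fine labels.
train_images = jnp.asarray(train_images, dtype = jnp.float32) / 255.0
train_labels = jax.nn.one_hot(train_labels, num_classes = 100)

print(train_images.shape, train_labels.shape)   # (50000, 32, 32, 3) (50000, 100)
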

Using keras.datasets

def setup():
    
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data()

    return (train_images, train_labels), (test_images, test_labels)

The printed output is as follows; note that, unlike the tensorflow_datasets version, the labels here have shape (50000, 1) rather than (50000,):


((50000, 32, 32, 3), (50000, 1)) ((10000, 32, 32, 3), (10000, 1))

Implementing the ResNet Residual Block

The ResNet architecture was introduced in the previous chapter. Its creative idea is to stack the network out of "modules", so that feature information passed within a module is not lost.

As the figure below shows, each module is in fact three convolution layers stacked into a bottleneck design. Each residual block uses three layers: a 1 x 1, a 3 x 3, and another 1 x 1 convolution. The 1 x 1 layers first reduce and then restore the channel dimension, leaving the 3 x 3 layer a bottleneck with smaller input and output sizes.

The three-layer convolution structure is implemented as follows:


import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    
    # Generate a main path
    def make_main(inputs_shape):
        
        return jax.example_libraries.stax.serial(
            
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main,
                                            jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

In this code, the input first passes through a jax.example_libraries.stax.Conv() convolution that shrinks the channel dimension (to one quarter of the output channels in the classical bottleneck design), reducing the amount of data before the [3, 3] convolution in the next layer. jax.example_libraries.stax.BatchNorm() is a batch-normalization layer, and jax.example_libraries.stax.Relu is an activation layer.

This code also uses three combinators not seen before; their purpose is to combine different computation paths. jax.example_libraries.stax.FanOut(2) duplicates the data, jax.example_libraries.stax.parallel(Main, Identity) runs the main path and the Identity path side by side, and jax.example_libraries.stax.FanInSum sums the outputs of the parallel branches.
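
As a quick check of how these combinators compose, here is a minimal sketch (the input shape is chosen arbitrarily) that initializes the IdentityBlock above and runs one forward pass; because the last convolution restores inputs_shape[3] channels, the output shape matches the input shape:

import jax
import jax.numpy as jnp

init_fun, apply_fun = IdentityBlock(3, [16, 16])

rng = jax.random.PRNGKey(0)
# An NHWC batch of 4 RGB images of 32 x 32 pixels
output_shape, params = init_fun(rng, (4, 32, 32, 3))

outputs = apply_fun(params, jnp.ones((4, 32, 32, 3)))
print(output_shape, outputs.shape)   # both (4, 32, 32, 3)
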

During propagation, the ResNet module adds an "information highway" known as the shortcut. The shortcut connection simply performs an identity mapping; it introduces no extra parameters and no additional computational complexity, as shown in the figure below.

Moreover, the whole network can still be trained end to end with backpropagation. In the IdentityBlock code above, this shortcut is exactly the jax.example_libraries.stax.Identity branch passed to jax.example_libraries.stax.parallel().


Sometimes, beyond deciding whether to transform the input, a ResNet implementation also changes the data's dimensions. When the input dimension differs from the module's required output dimension (input_channel does not equal out_dim), the input must be brought to the right shape, either by a padding operation that pads the data out to the required size, or, as in the ConvolutionalBlock below, by a 1 x 1 convolution on the shortcut.
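
To make this concrete, here is a hedged sketch (ProjectionBlock is an illustrative name, not part of the chapter's code): when the main path changes the channel count, an identity shortcut would no longer match in jax.example_libraries.stax.FanInSum, so a 1 x 1 convolution projects the shortcut to the same shape, which is exactly the pattern the ConvolutionalBlock below uses:

import jax
import jax.example_libraries.stax

def ProjectionBlock(out_channels):
    
    Main = jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(out_channels, (3, 3), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu
    )
    
    # A 1 x 1 convolution projects the shortcut to out_channels so that
    # FanInSum can add the two branches elementwise.
    Shortcut = jax.example_libraries.stax.Conv(out_channels, (1, 1), padding = "SAME")
    
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main, Shortcut),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

init_fun, apply_fun = ProjectionBlock(16)
output_shape, params = init_fun(jax.random.PRNGKey(0), (4, 32, 32, 3))
print(output_shape)   # (4, 32, 32, 16): both branches now have 16 channels
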

Implementing the ResNet Network

The ResNet network structure is shown in the figure below.

The figure lists five ResNet depths: 18, 34, 50, 101, and 152. All of them are divided into five stages: conv1, conv2_x, conv3_x, conv4_x, and conv5_x.

These are implemented below. Note that training the original ResNet requires a fairly powerful GPU. For ease of demonstration, the code below makes some simplifications relative to the paper (for example, the initial convolution is a 3 x 3 rather than a strided 7 x 7, which suits the 32 x 32 CIFAR inputs); please keep this in mind.

The full ResNet50 code is as follows:


import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    
    # Generate a main path
    def make_main(inputs_shape):
        
        return jax.example_libraries.stax.serial(
            
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main,
                                            jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

def ConvolutionalBlock(kernel_size, filters, strides = (1, 1)):
    
    kernel_size_ = kernel_size
    filters1, filters2, filters3 = filters
    
    Main = jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.Conv(filters1, (1, 1), strides = strides, padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        # The stride is applied once, in the first 1 x 1 convolution; striding
        # here as well would downsample the main path twice relative to the shortcut.
        jax.example_libraries.stax.Conv(filters3, (1, 1), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm()
    )
    
    Shortcut = jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(filters3, (1, 1), strides, padding = "SAME")
    )
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(
            Main,
            Shortcut
        ),
        
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu)

def ResNet50(number_classes):
    
    return jax.example_libraries.stax.serial(
        
        jax.example_libraries.stax.Conv(64, (3, 3), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        
        jax.example_libraries.stax.MaxPool((3, 3), strides = (2, 2)),
        
        ConvolutionalBlock(3, [64, 64, 256]),
        
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        
        ConvolutionalBlock(3, [128, 128, 512]),
        
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        
        ConvolutionalBlock(3, [256, 256, 1024]),
        
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        
        ConvolutionalBlock(3, [512, 512, 2048]),
        
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        
        jax.example_libraries.stax.AvgPool((7, 7)),
        
        jax.example_libraries.stax.Flatten,
        
        jax.example_libraries.stax.Dense(number_classes),
        
        jax.example_libraries.stax.LogSoftmax
    )
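
As a quick smoke test (the batch size and input shapes are assumptions for illustration), the following sketch initializes the ResNet50 above for CIFAR100's 100 classes and pushes a dummy batch through it:

import jax
import jax.numpy as jnp

init_fun, apply_fun = ResNet50(number_classes = 100)

rng = jax.random.PRNGKey(15)
# An NHWC batch of 8 CIFAR-sized images
output_shape, params = init_fun(rng, (8, 32, 32, 3))

log_probabilities = apply_fun(params, jnp.ones((8, 32, 32, 3)))
print(output_shape, log_probabilities.shape)   # both (8, 100)
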

Conclusion

This chapter described the structure of the CIFAR100 dataset and implemented the ResNet residual block and the full network, again laying the groundwork for the hands-on work ahead.
