The previous chapter covered the ResNet model and its components, and introduced the model components JAX ships under jax.example_libraries.stax. With that groundwork in place, we can start writing code. In this chapter we use ResNet to classify the CIFAR100 dataset.
What Is the CIFAR100 Dataset
CIFAR10 and CIFAR100 are both datasets of small labeled images. Compared with CIFAR10, CIFAR100 has 100 classes with 600 images each: 500 training images and 100 test images per class. The 100 classes are further grouped into 20 superclasses.
| Superclass | Classes |
|---|---|
| aquatic mammals | beaver, dolphin, otter, seal, whale |
| fish | aquarium fish, flatfish, ray, shark, trout |
| flowers | orchids, poppies, roses, sunflowers, tulips |
| food containers | bottles, bowls, cans, cups, plates |
| fruit and vegetables | apples, mushrooms, oranges, pears, sweet peppers |
| household electrical devices | clock, computer keyboard, lamp, telephone, television |
| household furniture | bed, chair, couch, table, wardrobe |
| insects | bee, beetle, butterfly, caterpillar, cockroach |
| large carnivores | bear, leopard, lion, tiger, wolf |
| large man-made outdoor things | bridge, castle, house, road, skyscraper |
| large natural outdoor scenes | cloud, forest, mountain, plain, sea |
| large omnivores and herbivores | camel, cattle, chimpanzee, elephant, kangaroo |
| medium-sized mammals | fox, porcupine, possum, raccoon, skunk |
| non-insect invertebrates | crab, lobster, snail, spider, worm |
| people | baby, boy, girl, man, woman |
| reptiles | crocodile, dinosaur, lizard, snake, turtle |
| small mammals | hamster, mouse, rabbit, shrew, squirrel |
| trees | maple, oak, palm, pine, willow |
| vehicles 1 | bicycle, bus, motorcycle, pickup truck, train |
| vehicles 2 | lawn-mower, rocket, streetcar, tank, tractor |
Each image carries a "fine" label (its class) and a "coarse" label (its superclass), and is 32x32 pixels in size.

The dataset can be downloaded in two ways. The first is to fetch it directly from the official site, which offers three versions:
| Version | Size | md5sum |
|---|---|---|
| CIFAR-100 python version | 161 MB | eb9058c3a382ffc7106e4002c42a8d85 |
| CIFAR-100 Matlab version | 175 MB | 6a4bfa1dcd5c9453dda6bb54194911f4 |
| CIFAR-100 binary version (suitable for C programs) | 161 MB | 03b5dce01913d631647c71ecec9e9cb8 |
Here we choose the Python version.
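The downloaded archive can be checked against the md5sum column above. A minimal sketch; the file name cifar-100-python.tar.gz is the default archive name and may differ on your machine:

import hashlib

def md5_of(path):
    # Stream the archive through hashlib to avoid loading it all into memory.
    digest = hashlib.md5()
    with open(path, mode = "rb") as handler:
        for chunk in iter(lambda: handler.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

print(md5_of("cifar-100-python.tar.gz") == "eb9058c3a382ffc7106e4002c42a8d85")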
The second way is to download it with tensorflow_datasets. Both approaches are walked through below.
Building the Dataset from the Downloaded CIFAR100 Files
After the CIFAR-100 python version is downloaded and extracted, it has the following file structure:
train
test
meta
file.txt~
Here, meta holds the dataset metadata, train is the training set, and test is the test set. The following code reads them:
import pickle

def setup():
    # Load a pickled CIFAR-100 file; the files are encoded with latin1.
    def load(fileName: str):
        with open(file = fileName, mode = "rb") as handler:
            data = pickle.load(file = handler, encoding = "latin1")
        return data
    trains = load("../../Shares/cifar-100-python/train")
    tests = load("../../Shares/cifar-100-python/test")
    metas = load("../../Shares/cifar-100-python/meta")
    return trains, tests, metas

def train():
    trains, tests, metas = setup()
    # Print each dictionary's keys and the length of the value stored under each.
    for key in trains.keys():
        print(f"key = {key}, len(trains[key]) = {len(trains[key])}")
    print("--------------------------------------------------")
    for key in tests.keys():
        print(f"key = {key}, len(tests[key]) = {len(tests[key])}")
    print("--------------------------------------------------")
    for key in metas.keys():
        print(f"key = {key}, len(metas[key]) = {len(metas[key])}")

def main():
    train()

if __name__ == "__main__":
    main()
Running it prints the following:
key = filenames, len(trains[key]) = 50000
key = batch_label, len(trains[key]) = 21
key = fine_labels, len(trains[key]) = 50000
key = coarse_labels, len(trains[key]) = 50000
key = data, len(trains[key]) = 50000
--------------------------------------------------
key = filenames, len(tests[key]) = 10000
key = batch_label, len(tests[key]) = 20
key = fine_labels, len(tests[key]) = 10000
key = coarse_labels, len(tests[key]) = 10000
key = data, len(tests[key]) = 10000
--------------------------------------------------
key = fine_label_names, len(metas[key]) = 100
key = coarse_label_names, len(metas[key]) = 20
The fields are as follows:
- filenames: a list of 50000 entries, each the file name of one image.
- batch_label: a string describing the batch.
- fine_labels: the class index of each image.
- coarse_labels: the superclass index of each image.
- data: a 50000 x 3072 two-dimensional array; each row holds the pixel values of one image.
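Within each row of data, the 1024 red values come first, then 1024 green, then 1024 blue, each plane stored row-major as 32 x 32. A minimal sketch that recovers one image as a 32 x 32 x 3 array, reusing the placeholder path from setup() above:

import pickle
import numpy

def row_to_image(row):
    # Each row is 3072 values: 1024 red, 1024 green, 1024 blue planes.
    # Reshape to (3, 32, 32) and move the channel axis last.
    return numpy.asarray(row, dtype = numpy.uint8).reshape(3, 32, 32).transpose(1, 2, 0)

with open("../../Shares/cifar-100-python/train", mode = "rb") as handler:
    trains = pickle.load(file = handler, encoding = "latin1")

image = row_to_image(trains["data"][0])
print(image.shape)  # (32, 32, 3)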
Using tensorflow_datasets
import tensorflow_datasets as tfds

def setup():
    # batch_size = -1 returns each split as one batch holding the whole set.
    (trains, tests), meta = tfds.load("cifar100", data_dir = "/tmp/", split = [tfds.Split.TRAIN, tfds.Split.TEST], with_info = True, batch_size = -1)
    # tfds.show_examples(trains, meta)
    trains = tfds.as_numpy(trains)
    tests = tfds.as_numpy(tests)
    train_images, train_labels = trains["image"], trains["label"]
    test_images, test_labels = tests["image"], tests["label"]
    return (train_images, train_labels), (test_images, test_labels)

def train():
    (train_images, train_labels), (test_images, test_labels) = setup()
    print((train_images.shape, train_labels.shape), (test_images.shape, test_labels.shape))

def main():
    train()

if __name__ == "__main__":
    main()
Running it prints the following:
((50000, 32, 32, 3), (50000,)) ((10000, 32, 32, 3), (10000,))
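These arrays still hold uint8 pixels and integer labels. A common step before training, shown here as a minimal sketch (preprocess is a hypothetical helper, not part of the loader), is to scale the pixels into [0, 1] floats and one-hot encode the labels:

import jax.numpy

def preprocess(images, labels, number_classes = 100):
    # Scale uint8 pixels into [0, 1] floats.
    images = jax.numpy.asarray(images, dtype = jax.numpy.float32) / 255.0
    # Index into an identity matrix to one-hot encode the integer labels.
    one_hots = jax.numpy.eye(number_classes)[labels]
    return images, one_hots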
Using keras.datasets
import tensorflow as tf

def setup():
    # keras returns numpy arrays directly; note the labels keep a trailing axis.
    (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar100.load_data()
    return (train_images, train_labels), (test_images, test_labels)
Running it prints the following:
((50000, 32, 32, 3), (50000, 1)) ((10000, 32, 32, 3), (10000, 1))
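Note the difference from the tensorflow_datasets shapes: keras returns the labels as (50000, 1) rather than (50000,). If the flat form is wanted, the trailing axis can be dropped; a minimal sketch:

import numpy

def flatten_labels(labels):
    # (50000, 1) -> (50000,), matching the tensorflow_datasets label shape.
    return numpy.asarray(labels).reshape(-1)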
Implementing the ResNet Residual Block
The ResNet architecture was introduced in the previous chapter. Its innovation is building the network by stacking "modular" residual blocks, so that the features flowing through each block are not lost.
As the figure below shows, the inside of a block is actually three convolution stages stacked into a bottleneck design. Each residual block uses three convolution layers: a 1 x 1, a 3 x 3, and another 1 x 1 layer. The 1 x 1 layers first reduce and then restore the channel dimension, leaving the 3 x 3 layer a bottleneck with smaller input and output sizes.
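To see the saving, count the weights for a block whose input and output both have 256 channels (an illustrative figure, ignoring biases): the bottleneck needs 1 x 1 x 256 x 64 + 3 x 3 x 64 x 64 + 1 x 1 x 64 x 256 = 16384 + 36864 + 16384 = 69632 parameters, while a single full-width 3 x 3 convolution would need 3 x 3 x 256 x 256 = 589824, roughly 8.5 times as many.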
The code implementing this three-layer convolution structure is as follows:
import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    # Generate a main path
    def make_main(inputs_shape):
        return jax.example_libraries.stax.serial(
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main,
                                            jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )
In this code the input first passes through a jax.example_libraries.stax.Conv() convolution whose output has a quarter of the block's output channels; this reduces the data volume before the [3, 3] convolution that follows. jax.example_libraries.stax.BatchNorm() is a batch-normalization layer, and jax.example_libraries.stax.Relu is an activation layer.
In addition, three constructs we have not met before appear here; their purpose is to combine different computation paths. jax.example_libraries.stax.FanOut(2) duplicates the data, jax.example_libraries.stax.parallel(Main, Identity) runs the main path and the Identity path side by side, and jax.example_libraries.stax.FanInSum sums the parallel results back together.
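Their behavior is easy to verify in isolation. A minimal sketch that wires Identity into both branches, so the combination simply computes f(x) = x + x:

import jax
import jax.numpy
import jax.example_libraries.stax

# FanOut duplicates the input, parallel routes each copy through its own
# (here trivial) path, and FanInSum adds the results back together.
init_fun, apply_fun = jax.example_libraries.stax.serial(
    jax.example_libraries.stax.FanOut(2),
    jax.example_libraries.stax.parallel(jax.example_libraries.stax.Identity,
                                        jax.example_libraries.stax.Identity),
    jax.example_libraries.stax.FanInSum
)
_, params = init_fun(jax.random.PRNGKey(0), (1, 4))
print(apply_fun(params, jax.numpy.ones((1, 4))))  # [[2. 2. 2. 2.]]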
As data propagates, a ResNet block adds a "shortcut", a fast lane for information. The shortcut connection simply performs an identity mapping; it introduces no extra parameters and adds no computational complexity, as shown in the figure below.
Moreover, the whole network can still be trained end to end with backpropagation. The shortcut is exactly the jax.example_libraries.stax.Identity branch in the IdentityBlock code above.
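Because the shortcut is a pure identity, an IdentityBlock must map its input to an output of the same shape. A quick check, a minimal sketch assuming the IdentityBlock definition above and hypothetical NHWC dimensions:

import jax
import jax.numpy

# Initialize an IdentityBlock for NHWC inputs with 256 channels and
# confirm that the output shape matches the input shape.
init_fun, apply_fun = IdentityBlock(3, [64, 64])
output_shape, params = init_fun(jax.random.PRNGKey(0), (4, 8, 8, 256))
outputs = apply_fun(params, jax.numpy.ones((4, 8, 8, 256)))
print(output_shape, outputs.shape)  # (4, 8, 8, 256) twice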
Sometimes, however, the shortcut cannot stay a pure identity: ResNet changes the data's dimensions along the way, so when the input dimension differs from the dimension the block must output (input_channel does not equal out_dim), the input on the shortcut has to be transformed to match. In the ConvolutionalBlock below this is done with a 1 x 1 convolution (using the padding parameter set to "SAME") on the shortcut path.
Implementing the ResNet Network
The ResNet network structure is shown in the figure below.
The figure lists ResNet at five depths: 18, 34, 50, 101 and 152 layers. Each network is divided into five stages: conv1, conv2_x, conv3_x, conv4_x and conv5_x.
These are implemented below. Note that training a full ResNet calls for a fairly powerful GPU. To keep the demonstration manageable, the code below makes a few simplifications relative to the paper, most visibly in the stem convolution and pooling; please keep this in mind.
The resulting ResNet50 code is as follows:
import jax.example_libraries.stax

def IdentityBlock(kernel_size, filters):
    kernel_size_ = kernel_size
    filters1, filters2 = filters
    # Generate a main path
    def make_main(inputs_shape):
        return jax.example_libraries.stax.serial(
            jax.example_libraries.stax.Conv(filters1, (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm(),
            jax.example_libraries.stax.Relu,
            # Adjust according to the inputs automatically
            jax.example_libraries.stax.Conv(inputs_shape[3], (1, 1), padding = "SAME"),
            jax.example_libraries.stax.BatchNorm()
        )
    Main = jax.example_libraries.stax.shape_dependent(make_layer = make_main)
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(Main,
                                            jax.example_libraries.stax.Identity),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu
    )

def ConvolutionalBlock(kernel_size, filters, strides = (1, 1)):
    kernel_size_ = kernel_size
    filters1, filters2, filters3 = filters
    Main = jax.example_libraries.stax.serial(
        # The stride is applied once, in the first convolution; applying it in
        # the last convolution as well would downsample the main path twice
        # while the shortcut downsamples only once, breaking FanInSum.
        jax.example_libraries.stax.Conv(filters1, (1, 1), strides = strides, padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(filters2, (kernel_size_, kernel_size_), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.Conv(filters3, (1, 1), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm()
    )
    # A 1 x 1 convolution on the shortcut matches its channels (and, via the
    # stride, its spatial size) to the main path before the sum.
    Shortcut = jax.example_libraries.stax.serial(
        jax.example_libraries.stax.Conv(filters3, (1, 1), strides, padding = "SAME")
    )
    return jax.example_libraries.stax.serial(
        jax.example_libraries.stax.FanOut(2),
        jax.example_libraries.stax.parallel(
            Main,
            Shortcut
        ),
        jax.example_libraries.stax.FanInSum,
        jax.example_libraries.stax.Relu)

def ResNet50(number_classes):
    return jax.example_libraries.stax.serial(
        # Stem, simplified relative to the paper's 7 x 7 stride-2 convolution.
        jax.example_libraries.stax.Conv(64, (3, 3), padding = "SAME"),
        jax.example_libraries.stax.BatchNorm(),
        jax.example_libraries.stax.Relu,
        jax.example_libraries.stax.MaxPool((3, 3), strides = (2, 2)),
        # conv2_x
        ConvolutionalBlock(3, [64, 64, 256]),
        IdentityBlock(3, [64, 64]),
        IdentityBlock(3, [64, 64]),
        # conv3_x
        ConvolutionalBlock(3, [128, 128, 512]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        IdentityBlock(3, [128, 128]),
        # conv4_x
        ConvolutionalBlock(3, [256, 256, 1024]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        IdentityBlock(3, [256, 256]),
        # conv5_x
        ConvolutionalBlock(3, [512, 512, 2048]),
        IdentityBlock(3, [512, 512]),
        IdentityBlock(3, [512, 512]),
        jax.example_libraries.stax.AvgPool((7, 7)),
        jax.example_libraries.stax.Flatten,
        jax.example_libraries.stax.Dense(number_classes),
        jax.example_libraries.stax.LogSoftmax
    )
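As a sanity check, the network can be initialized and run on a tiny batch; a minimal sketch with a hypothetical batch size of 2 (stax's init_fun returns the output shape alongside the parameters):

import jax
import jax.numpy

# Build the network for 100 classes and push a dummy NHWC batch through it.
init_fun, apply_fun = ResNet50(100)
output_shape, params = init_fun(jax.random.PRNGKey(0), (2, 32, 32, 3))
logits = apply_fun(params, jax.numpy.ones((2, 32, 32, 3)))
print(output_shape, logits.shape)  # both (2, 100); rows are log-probabilities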
Conclusion
This chapter walked through the structure of the CIFAR100 dataset as well as the implementation of the ResNet residual blocks and the full network, laying the groundwork for the hands-on training to come.