So you want to create a new model!!
在本节中,我们将讨论用于定义检测模型的一些抽象。如果您想定义一个新的模型体系结构以进行检测并在Tensorflow Detection API中使用它,那么本节还应该作为需要编辑以使新模型正常工作的文件的高级指南。
DetectionModels(object_detection/core/model.py
)
为了使用我们提供的二进制文件进行训练,评估和导出,Tensorflow Object Detection API下的所有模型(Faser RCNN,Mask RCNN,SSD等)都使用DetectionModel
接口(请参阅完整定义object_detection/core/model.py
)。DetectionModel实现5个功能:
-
preprocess
:在图片输入到检测器之前,需要对图片进行必要的预处理(例如,缩放/移位/整形)。 -
predict
:生成可以传递给损失函数或后处理函数(postprocess)的“原始图片”预测张量。 -
postprocess
:将预测(predict)输出张量转换为最终检测的图片。 -
loss
:针对提供的真是标签(ground_truth)计算标量损失张量。 -
restore
:将检查点加载到Tensorflow图中。
给定DetectionModel
训练时间,我们通过以下函数序列传递每个图像批次,以计算可通过SGD优化的损失:
inputs (images tensor) -> preprocess -> predict -> loss -> outputs (loss tensor)
在eval时间,我们通过以下函数序列传递每个图像批次以生成一组检测:
inputs (images tensor) -> preprocess -> predict -> postprocess -> outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
一些规定:
-
DetectionModel
不应该对输入大小或宽高比做任何假设(也就是可以对任意图片进行检测) - 它们负责进行必要的调整大小/重新整形(参见preprocess
函数的注释 )。 - 输出类始终是在
[0, num_classes)
整数范围内的数,没有预先假设背景类别。 - 检测到的框将被解释为
[y_min, x_min, y_max, x_max]
格式化并相对于图像窗口标准化。 - 我们没有具体假设对分数的任何概率解释 - 仅仅进行了相对排序。因此,后处理功能的实现可以自由地输出对数,概率,校准概率或其他任何东西。
定义新的Faster R-CNN or SSD Feature Extractor
在大多数情况下,不会从头写DetectionModel
- 一般是创建一个新的功能提取器,供其中一个SSD或Faster R-CNN 的meta-architectures.模型使用。(meta-architectures是DetectionModel
子的类)。
注意:为了使下面的讨论有意义,建议首先熟悉Faster R-CNN 论文。
如果使用一种全新的网络架构(比如说,“InceptionV100”)进行分类,并希望了解InceptionV100如何作为检测的特征提取器(例如,使用Faster R-CNN)。
要使用InceptionV100,我们必须定义一个新的 FasterRCNNFeatureExtractor
并将其FasterRCNNMetaArch
作为输入传递给我们的构造函数。
在object_detection/meta_architectures/faster_rcnn_meta_arch.py
。分别定义了FasterRCNNFeatureExtractor
和FasterRCNNMetaArch
。
FasterRCNNFeatureExtractor
必须定义的几个功能:
-
preprocess
:在输入图像上运行检测器之前,运行对输入值进行的任何预处理。 -
_extract_proposal_features
:提取第一阶段区域提议网络(RPN)功能。 -
_extract_box_classifier_features
:提取第二阶段Box分类器功能。 -
restore_from_classification_checkpoint_fn
:将检查点加载到Tensorflow图中。
使用object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py
举一个例子。:
- 使用Slim Resnet-101分类检查点的权重来初始化此特征提取器的权重 ,在此检查点模型内部对图像进行了预处理,通过从每个输入图像中减去通道平均值。因此,需要实现预处理函数来重现相同的通道平均减法行为。
- 在slim中定义的“完整”resnet分类网络被分成两部分 - 除last “resnet block”之外的所有部分都被传入到
_extract_proposal_features
函数中,last “resnet block”传入到_extract_box_classifier_features function
函数中。一般情况下,可能需要进行一些实验来确定最佳层,以便将特征提取器“切割”为这两个部分,以实现FasterRCNN。
配置自己的模型参数
假设feature extractor不需要标准配置,理想情况下,希望能够简单地更改配置中的“feature_extractor.type”字段以指向新的功能提取器。为了让我们的API知道如何理解这种新类型,您首先必须使用模型构建器(object_detection/builders/model_builder.py
)编写新的feature extractor,其作用是从配置原型创建模型。
创建很简单---只需添加一个指针,该指针指向您在object_detection/builders/model_builder.py
文件顶部的一个SSD或FasterRCNN特征提取器类映射中定义的新的Feature Extractor类 。建议添加一个测试,object_detection/builders/model_builder_test.py
以确保解析新的proto将按预期工作。(在model_builder.py有个字典把自己的CNN模型添加进去就可以了)
把新模型做的更加性感一点
创建好模型之后,就可以使用新的模型!最终提示:
- 要节省调试时间,请首先尝试在本地运行配置文件(包括培训和评估)。
- 学习一定的学习率,以确定哪种学习率最适合新的模型。
- 一个小但通常很重要的细节:可能会发现有必要禁用BN训练(即,从分类检查点加载批处理规范参数,但在梯度下降期间不要更新它们)。