Faster RCNN 推理 从头写 java (二) RPN网络预测

目录:

一: 输入输出

输入:

  • omg: 经过预处理过的图像, shape为 [1, 600, 800, 3].

输出:

  • cls: 每个anchor在pixel上的概率, shape为 [1, 37, 50, 49].
  • reg: 每个anchor在pixel上的回归值, shape 为 [1, 37, 50, 196].
  • feature: 经过VGG16后的feature map, shape 为 [1, 37, 50, 512].

二: 流程

  • 图片BGR 格式转换为 RGB 格式。
  • 图片缩放。
  • 图片均值中值化。

三: code by code

img 转换为tensorflow 的 Tensor

Tensor<Float> input = TypeConvertor.ndarrayToTensor(img);

预测

List<Tensor<?>> output = this.session.runner().
        feed(INPUT_NAME, input).
        fetch(OUTPUT_CLS_NAME).fetch(OUTPUT_REG_NAME).fetch(OUTPUT_FEATURE_MAP_NAME).
        run();

构建输出
0: cls
1: reg
3: feature

return new FasterRCnnRPN_Output(output.get(0), output.get(1), output.get(2));
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。