tensorflow API使用笔记 Bucketize

tensorflow分桶API,有好几个接口,其中带boundaries的接口C++实现如下:

template <typename T>
struct BucketizeFunctor<CPUDevice, T> {
  // PRECONDITION: boundaries_vector must be sorted.
  static Status Compute(OpKernelContext* context,
                        const typename TTypes<T, 1>::ConstTensor& input,
                        const std::vector<float>& boundaries_vector,
                        typename TTypes<int32, 1>::Tensor& output) {
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      auto first_bigger_it = std::upper_bound(
          boundaries_vector.begin(), boundaries_vector.end(), input(i));
      output(i) = first_bigger_it - boundaries_vector.begin();
    }

    return Status::OK();
  }
};
  • 输入:input tensor和boundaries_vector
  • 输出:output tensor

使用stl的upper_bound算法查找第一个大于输入值的bound,然后返回这个bound的偏移索引。

所以这里,用户需要指定一个boundary,并且boundary是不变的。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容