TensorFlow Serving 剖析1 - 读取 Servable

在最新 TensorFlow 中, TensorFlow Serving 是 TensorFlow Extended 的一部分
TensorFlow Extended 官方主页:https://www.tensorflow.org/tfx
TensorFlow Serving 的 Guide:https://www.tensorflow.org/tfx/guide/serving
TensorFlow Serving 的源码:https://github.com/tensorflow/serving.git
约定源码根目录为 tensorflow_serving/

本节主要介绍架构上如何读取一个 Servable,涉及到的类有 Servable, LoaderSourceAdapter
更多信息可以参考:https://www.tensorflow.org/tfx/serving/architecture

Servable

Servable 可以是任何类,用其他类对其进行封装。
ServableHandleServable 指针的封装,多了一个 id() 接口。这里用到类似 d-pointer 的设计模式:https://wiki.qt.io/D-Pointer。(Google 代码中 d-pointer 被广泛使用),不同的是 UntypedServableHandle 不是 ServableHandle 的派生类, AnyPtr 类似于 C 中的 void *

// core/servable_handle.h
template <typename T>
class ServableHandle {
 public:
  const ServableId& id() const { return untyped_handle_->id(); }
  T& operator*() const { return *get(); }
  T* operator->() const { return get(); }
  T* get() const { return servable_; }
  operator bool() const { return get() != nullptr; }

 private:
  friend class Manager;
  std::unique_ptr<UntypedServableHandle> untyped_handle_;
  T* servable_ = nullptr;
};

class UntypedServableHandle {
 public:
  virtual const ServableId& id() const = 0;
  virtual AnyPtr servable() = 0;
};

在源码中,有如下 handle,其中最重要的是 ServableHandle<SavedModelBundle>ServableHandle<SessoinBundle>

ServableHandle<string>
ServableHandle<int>
ServableHandle<int64>
ServableHandle<TestServable>
ServableHandle<float>
ServableHandle<StoragePath>
ServableHandle<SavedModelBundle>
ServableHandle<SessionBundle>

ServableData 是对 Servable 数据的封装,增加了 id(), status() 等接口,还增加了 DataOrDie() 等查错接口

// core/servable_data.h
template <typename T>
class ServableData {
 public:
  const ServableId& id() const { return id_; }
  const Status& status() const { return status_; }
  const T& DataOrDie() const;
  T& DataOrDie();
  T ConsumeDataOrDie();
 private:
  const ServableId id_;
  const Status status_;
  T data_;
}

// Die 的实现
// CHECK 在 tensorflow/core/platform/default/logging.h 中实现
#define CHECK(conditon) \
  if (TF_PREDICT_FALSE(!(condition))) \
  LOG(FATAL) << "Check failed: " #condition " "
template <typename T>
T& ServableData<T>::DataOrDie() {
  CHECK(status_.ok());
  return data_;
}

参考 https://www.tensorflow.org/tfx/serving/custom_servable 可知,要让框架支持 Servable,需要配套的 LoaderSourceAdapter

Loader

包含 Load()Unload()servable() 等接口
先通过 SourceAdapter 获得 Loader,然后使用 Load() 载入 Servable,使用 servable() 访问 Servable,最后使用 Unload() 卸载 Servable。部分 SourceAdapter 可以获得 Loader

SourceAdapter InputType OutputType
SourceAdapter ? ?
+ ErrorInjectingSourceAdapter ? ?
+ LimitedAdapter StoragePath StoragePath
+ IdentitySourceAdapter ? save as input
+ UnarySourceAdapter ? ?
+ PrefixStoragePathSourceAdapter StoragePath StoragePath
+++ FakeStoragePathSourceAdapter StoragePath StoragePath
+++ SavedModelBundleSourceAdapter StoragePath std::unique_ptr<Loader>
+++ SessionBundleSourceAdapter StoragePath std::unique_ptr<Loader>
+++ SimpleLoaderSourceAdapter ? std::unique_ptr<Loader>
+++++ SimpleLoaderSourceAdapterImpl ? std::unique_ptr<Loader>
+++++ FakeLoaderSourceAdapter Storage string
+++++ HashmapSourceAdapter Storage std::unordered_map<string, string>
// core/loader.h
class Loader {
 public:
  virtual Status EstimateResources(ResourceAllocation* estimate) const = 0;
  virtual Status Load() {
    return errors::Unimplemented("Load isn't implemented.");
  }
  struct Metadata {
    ServableId servable_id;
  }
  virtual Status LoadWithMetaData(const Metadata& metadata) { return Load(); }
  virtual void Unload() = 0;
  virtual AnyPtr servable() = 0;
};

using LoaderSource = Source<StoragePath>;

SourceAdapter

接口定义如下,比较关键的是 Adapt(),输入是 InputType 输出 OutputType,一般 InputType=StoragePathOutputType=std::unique_ptr<Loader>

// core/source_adapter.h
template <typename InputType, typename OutputType>
class SourceAdapter : public TargetBase<InputType>, public Source<OutputType> {
 public:
  void SetAspiredVersions(...) final;
  void SetAspiredVersionsCallback(...) final;
  virtual std::vector<ServableData<OutputType>> Adapt(
      const StringPiece servable_name,
      std::vector<ServableData<inputType>> versions) = 0;
  ServableData<OutputType> AdaptOneVersion(ServableData<InputType> input);
 private:
  typename Source<OutputType>::AspiredVersionsCallback outgoing_callback_;
  Notification outgoing_callback_set_;
};

SourceApdapter 是抽象类,子类派生关系如下。UnarySourceAdapterSimpleLoaderSourceAdatper 有更多的继承关系

ErrorInjectingSourceAdapter final
LimitedAdapter  final // 用于测试
IdentitySourceAdapter final
UnarySourceAdapter
  PrefixStoragePathSourceAdapter final
  FakeStoragePathSourceAdapter final 
  SavedModelBundleSourceAdapter final
  SessionBundleSourceAdapter final
  SimpleLoaderSourceAdapter
    SimpleLoaderSourceAdapterImpl final
    FakeLoaderSourceAdapter final
    HashmapSourceAdapter final

UnarySourceAdapter 通过 Convert() 将一个 InputType 转换为一个 OutputType,在 Adapt() 中调用 Convert() 实现接口。注意:1. Adapt() 被改为 private 的了,详情可以看这里的讨论 stackoverflow。2. Convert() 是纯虚函数。可以看出,设计者希望使用者继承 UnarySourceAdapter 然后实现自己的 Convert() 函数,Adapt() 保持不动

// core/source_adapter.h
template <typename InputType, typename OutputType>
class UnarySourceAdapter : public SourceAdapter<InputType, OutputType> {
 private:
  std::vector<ServableData<OutputType>>
  UnarySourceAdapter<InputType, OutputType>::Adapt(
      const StringPiece servable_name,
      std::vector<ServableData<InputType> versions) {
    std::vector<ServableData<OutputType>> adapted_versions;
    for () {
      ...
      OutputType adapted_data;
      Status adapt_status = Convert(version.DataOrDie(), &adapted_data);
    }
    return adapted_version;
  }
  virtual Status Convert(const InputType& data, OutputType* converted_data) = 0;
};

SimpleLoaderSourceAdapter:1. 包含两个 functioncreatorresource_estimator,可以通过 protected 构造函数设置。2. 用这两个函数实现了 Convert() 接口。3. 没有重写 Adapt(),使用基类 UnarySourceAdapter 的实现。4. Convert() 返回 Loader 的子类 SimpleLoader,里边间接用到了两个 function
虽然实现复杂,但是扩展 SimpleLoaderSourceAdapter 很简单:1. 继承 SimpleLoaderSourceAdapter。2. 实现的自己的 creatorresource_estimator (后者有默认实现)。3. 用 creatorresource_estimator 初始化基类。
resource_estimator 有一个默认实现,调用 EstimateNoResources() 获得

template <typename DataType, typename ServableType>
class SimpleLoaderSourceAdapter
    : public UnarySourceAdapter<DataType, std::unique_ptr<Loader>> {
 public:
  using Creator =
      std::function<Status(const DataType&, std::unique_ptr<ServableType>*)>;
  using ResourceEstimator =
      std::funciton<Status(const DataType&, ResourceAllocation*)>;
 protected:
  SimpleLoaderSourceAdapter(Creator creator,
                            ResourceEstimator resource_estimator);
  Status Convert(const DataType& data, std::unique_ptr<Loader>* loader) final {
    loader->reset(new SimpleLoader<ServableType>(func1, func2));
    return Status::OK();
  }
}

HashmapSourceAdapter 为例 官方教程 ,首先继承 SimpleLoaderSourceAdapter

// servables/hashmap/hashmap_source_adapter.h
using Hashmap = std::unordered_map<string, string>;
class HashmapSourceAdapter final
    : public SimpleLoaderSourceAdapter<StoragePath, Hashmap> {
  ...
};

实现 creator,使用默认的 resource_estimator

Status LoadHashmapFromFile(const string& path,
                           const HashmapSourceAdapterConfig::Format& format,
                           std::unique_ptr<Hashmap>* hashmap) {
  ...
}

// 伪代码
function creator = [config](const StoragePath& path, std::unique_ptr<Hashmap>* hashmap) {
  return LoadHashmapFromFile(path, config.format(), hashmap);
}
function resource_estimator = SimpleLoaderSourceAdapter<...>::EstimateNoResources();

creatorresource_estimator 初始化基类

  HasnmapSourceAdapter::HashmapSourceAdapter(
      const HashmapSourceAdapterConfig& config)
      : SimpleLoaderSourceAdapter<StoragePath, Hashmap>(creator, resource_estimator) {
}

其他的 Convert()Adapt()SimpleLoader 源码里都已经实现了。

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。