本文主要介绍如何使用tensorflow serving官方提供的hashmap sourceadapter。其实理解了如何使用这个hashmapsourceadapter
,也就真正能对serving进行二次开发,网上几乎找不到任何相关资料(stackoverflow上的问题Tensorflow-serving: Serving a hashmap,而为了真正让这个hashmapsourceadapter
有用,我把serving的代码全部看了一遍。
这里一共分为6个步骤,第7步为接口测试。
首先,官方的代码只是定义了 HashmapSourceAdapter
。没有注册。
1 定义 HashmapSourceAdapterCreator
如下
// register this source adapter
class HashmapSourceAdapterCreator {
public:
static Status Create(
const HashmapSourceAdapterConfig& config,
std::unique_ptr<SourceAdapter<StoragePath, std::unique_ptr<Loader>>>*
adapter) {
adapter->reset(new HashmapSourceAdapter(config));
return Status::OK();
}
};
同时,把这个类加为 HashmapSourceAdapter
的友元。
private:
friend class HashmapSourceAdapterCreator;
2 注册 HashmapSourceAdapter
REGISTER_STORAGE_PATH_SOURCE_ADAPTER(HashmapSourceAdapterCreator,
HashmapSourceAdapterConfig);
3 使用hashmap servable
这一步非常重要,当然这一步不是非要我这么做,但这是最简单的方法。添加这个步骤的原因在于,标准的C++编译程序时,如果一个文件中的代码如果没有被调用,它就会被编译器优化掉。所以,这其实是一个hack。
第一步,在 hashmap_source_adapter.h 定义一个函数
void loadHashmapServable();
第二步,在hashmap_source_adapter.cc 中实现这个函数。
void loadHashmapServable() {
LOG(INFO) << "load hashmap servable...";
}
第三步,在main函数的开头调用这个函数
tensorflow::serving::loadHashmapServable();
4 http中添加使用hashmap servable的接口
这一步,我们会在原有http接口的基础上,添加一个lookup接口。
一, 在ProcessRequest
中添加分支lookup
} else if (method == "lookup") {
status = ProcessLookupRequest(model_name, model_version, request_body,
output);
}
二,定义函数 ProcessLookupRequest
Status HttpRestApiHandler::ProcessLookupRequest(
const absl::string_view model_name,
const absl::optional<int64>& model_version,
const absl::string_view request_body, string* output) {
ModelSpec model_spec;
model_spec.set_name(string(model_name));
if (model_version.has_value()) {
model_spec.mutable_version()->set_value(model_version.value());
}
ServableHandle<std::unordered_map<string, string>> bundle;
TF_RETURN_IF_ERROR(core_->GetServableHandle(model_spec, &bundle));
std::unordered_map<std::string, std::string>::const_iterator got = bundle->find(request_body.data());
if (got == bundle->end()) {
output->assign(string("None"));
} else {
output->assign(got->second);
}
return Status::OK();
}
三,放开URL正则匹配的限制
prediction_api_regex_(
R"((?i)/v1/models/([^/:]+)(?:/versions/(\d+))?:(classify|regress|predict|lookup))"),
5 给hashmap servable加载的文件添加一个文件名
我们程序启动时会去模型目录下加载一个名为 data.csv
的文件。
const string fpath = io::JoinPath(path, "data.csv");
std::unique_ptr<RandomAccessFile> file;
TF_RETURN_IF_ERROR(Env::Default()->NewRandomAccessFile(fpath, &file));
该文件的格式如下:
key0,value0
key1,value1
tom,jerry
pete,henry
hello,world
good,bye
6 从配置文件启动TF serving
因为加入了hashmap servable,tensorflow serving不止支持一个platform,当tensorflow serving支持多个platform的时候需要从配置文件启动,命令如下:
tensorflow_model_server --port=8500 --rest_api_port=8501 --platform_config_file=./etc/platform.conf --model_config_file=./etc/models.conf
其中,platform.conf内容如下:
platform_configs {
key: "hashmap"
value {
source_adapter_config {
type_url: "type.googleapis.com/tensorflow.serving.HashmapSourceAdapterConfig"
}
}
}
platform_configs {
key: "tensorflow"
value {
source_adapter_config {
type_url: "type.googleapis.com/tensorflow.serving.SavedModelBundleSourceAdapterConfig"
value: "\302>\002\022\000"
}
}
}
models.conf内容如下:
model_config_list: {
config: {
name: "tensorflow",
base_path: "/data/models/tensorflow",
model_platform: "tensorflow",
model_version_policy: {
all: {}
}
}
config: {
name: "hash",
base_path: "/data/models/hash",
model_platform: "hashmap",
model_version_policy: {
all: {}
}
}
}
platform.conf的编写可以参考这个issue
how to write a tensorflow serving platform_config_file
7 测试
访问接口
curl -d 'hello' -X POST http://localhost:8501/v1/models/hash/versions/1:lookup
输出
world