serving最核心的抽象是 servables。如果需要扩展serving,最需要完成的就是自定义servables。如何自定义servable可以参考官方文章,Creating a new kind of servable。但定义好servable后如何使用,官方和网上并没有给出说明。本文先对servable注册机制进行说明。后续会以如何使用官方的hashmap来进行具体介绍。
简单来说,对于用户实现的每一个servable,都需要注册。看官方的savedmodelsourceadapter,注册代码如下:
REGISTER_STORAGE_PATH_SOURCE_ADAPTER(SavedModelBundleSourceAdapterCreator,
SavedModelBundleSourceAdapterConfig);
其调用了一个宏 REGISTER_STORAGE_PATH_SOURCE_ADAPTER
,这个宏的定义如下:
#define REGISTER_STORAGE_PATH_SOURCE_ADAPTER(ClassCreator, ConfigProto) \
REGISTER_CLASS(StoragePathSourceAdapterRegistry, StoragePathSourceAdapter, \
ClassCreator, ConfigProto);
是对另一个更通用的宏 REGISTER_CLASS
的封装。REGISTER_CLASS
的定义如下:
// Registers a factory that creates subclasses of BaseClass by calling
// ClassCreator::Create().
#define REGISTER_CLASS(RegistryName, BaseClass, ClassCreator, config_proto, \
...) \
REGISTER_CLASS_UNIQ_HELPER(__COUNTER__, RegistryName, BaseClass, \
ClassCreator, config_proto, ##__VA_ARGS__)
#define REGISTER_CLASS_UNIQ_HELPER(cnt, RegistryName, BaseClass, ClassCreator, \
config_proto, ...) \
REGISTER_CLASS_UNIQ(cnt, RegistryName, BaseClass, ClassCreator, \
config_proto, ##__VA_ARGS__)
#define REGISTER_CLASS_UNIQ(cnt, RegistryName, BaseClass, ClassCreator, \
config_proto, ...) \
static ::tensorflow::serving::internal::ClassRegistry< \
RegistryName, BaseClass, ##__VA_ARGS__>::MapInserter \
register_class_##cnt( \
(config_proto::default_instance().GetDescriptor()->full_name()), \
(new ::tensorflow::serving::internal::ClassRegistrationFactory< \
BaseClass, ClassCreator, config_proto, ##__VA_ARGS__>));
这段代码什么作用呢?
说白了,就是定义了一个宏,这个宏的作用是调用某个类的构造函数,在构造函数里完成了真正的注册机制。
真正的注册代码如下:
// Nested class whose instantiation inserts a key/value pair into the factory
// map.
class MapInserter {
public:
MapInserter(const string& config_proto_message_type, FactoryType* factory) {
InsertIntoMap(config_proto_message_type, factory);
}
};
private:
// Inserts a key/value pair into the factory map.
static void InsertIntoMap(const string& config_proto_message_type,
FactoryType* factory) {
LockableFactoryMap* global_map = GlobalFactoryMap();
{
mutex_lock lock(global_map->mu);
global_map
对于每一个调用了REGISTER_CLASS的类,都会被注册,也就是插入到这个全局map中,下一次需要这个类的时候,就从这个map中查找。
我实现了一个简单的宏定义为类构造函数的测试类。
#include <stdio.h>
class Test {
public:
Test(int x, int cnt) {
printf("ddddd %d, %d\n", x, cnt);
}
};
#define FUNC(x) static Test rr##x(x, __COUNTER__)
#define DF(x) FUNC(x)
//static Test a;
DF(3);
DF(4);
int main()
{
printf("main...\n");
return 0;
}
从上面可以看出,宏经过预处理后,其实就是一个全局静态变量。这个变量的初始化会在main函数执行之前,也就是类的注册会在整个程序执行之前,这一点非常重要。
如果没有这一点保证,在下面的代码中就会报找不到类的错误。
// Creates an instance of BaseClass based on a config proto embedded in an Any
// message.
//
// Requires that the config proto in the Any has a compiled-in descriptor.
static Status CreateFromAny(const google::protobuf::Any& any_config,
AdditionalFactoryArgs... args,
std::unique_ptr<BaseClass>* result) {
// Copy the config to a proto message of the indicated type.
string full_type_name;
Status parse_status =
ParseUrlForAnyType(any_config.type_url(), &full_type_name);
if (!parse_status.ok()) {
return parse_status;
}
const protobuf::Descriptor* descriptor =
protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName(
full_type_name);
if (descriptor == nullptr) {
return errors::Internal(
"Unable to find compiled-in proto descriptor of type ",
full_type_name);
}
std::unique_ptr<protobuf::Message> config(
protobuf::MessageFactory::generated_factory()
->GetPrototype(descriptor)
->New());
if (!any_config.UnpackTo(config.get())) {
return errors::InvalidArgument("Malformed content of Any: ",
any_config.DebugString());
}
return Create(*config, std::forward<AdditionalFactoryArgs>(args)...,
result);
}
而这个方法,就是serving中从pb消息反射具体platform类的工具。只有这一步执行成功,程序才能根据用途提供的platform配置文件 实例化出真正的 platform 实例。