1. 前言
从毕业开始工作已经两个多月,这期间相当一部分的时间都用在了对MxNet的学习上,而在MxNet的众多部分中,又是pslite这一部分接触最多。因此,今天将我一直以来的学习过程中的心得和收获总结在这里,也为以后对MxNet的继续学习做一个铺垫
2. MxNet构成
MxNet作为一个深度学习框架,它最大的特点应该是分布式训练的支持了。从初次接触MxNet到现在的两个多月里,我认为MxNet主要有以下几个大的部分:
- symbol和graph,负责构建计算的图和反向传播的图
- Engine,负责根据图的节点的依赖关系,并行运算
- Parameter server和KVStore,负责参数的同步和传递
- operator,定义图的节点的op
- NDArray,数据的存储和计算
- Executor,对图进行处理用于计算
3. Parameter server
参数服务器的概念并不复杂,主要思想就是,将模型的参数保存在server中,另外通过worker来完成具体的计算任务,当每完成一个计算任务时就会得到对应参数的梯度,这时将梯度传送给server,由server来完成参数的更新,worker再从server那里取回更新后的参数。
在MxNet中,当我们需要进行分布式的训练时,就需要使用到它了。在MxNet中,为了完成参数在不同机器前的同步和更新,主要实现了两大部分。一是pslite,另一个是KVStore。
3.1 KVStore
为了更方便理解,我从KVStore
开始讲起。在MXNet中,可能很多人并不会直接操作KVStore,在官方文档中,甚至提到,不建议直接操作KVStore,但是,每个人在使用MXNet的过程中,都肯定用到了KVStore。其实,在我们建立module.Module
的时候,就会调用KVStore
的push
和pull
操作。
kvstore
主要分为两种,一种是单机下,一种是多机下。单机下又分为将参数存储在GPU显存和CPU内存上两种情况。
3.1.1 comm.h
在comm.h
文件中定义了Comm
类,该类用于设备间的信息传递,也就是communication。从Comm
类中派生出了两个子类CommCPU
用于CPU内存通信,CommDevice
用于GPU通信。
在Comm
类中定义了几个纯虚函数:
-
Init
:根据存储类型和shape初始化 -
Reduce
:输入NDArray的一个vector
,返回它们的和 -
Broadcast
:将一个NDArray复制到vector
中的每一个元素 BroadcastRowSparse
CommCPU
将数据复制到CPU内存中,在那里做操作。
-
Init
:初始化key
对应的KVStore,创建key
对应的NDArray
,保存在merge_buf_[key].merged
中。(不分配内存) -
Reduce
:将输入的vector<NDArray>& src
的每个元素求和并返回。当src
只有一个元素时,若不是sparse
的就直接返回src[0]
,若是则将src[0]
拷贝至merged_buf
返回。如果src
元素多于一个,那么:
if (stype == kDefaultStorage) {
std::vector<Engine::VarHandle> const_vars(src.size() - 1); // 定义engine pushasync的输入,用于engine根据该操作的输入来规划操作的执行
std::vector<NDArray> reduce(src.size());
CopyFromTo(src[0], &buf_merged, priority);
reduce[0] = buf_merged;
if (buf.copy_buf.empty()) { // copy_buf用于GPU数据拷贝至CPU,由于第0个元素存储在buf_merged,这里只需要src.size()-1个
buf.copy_buf.resize(src.size()-1);
for (size_t j = 0; j < src.size() - 1; ++j) {
// allocate copy buffer
buf.copy_buf[j] = NDArray(
src[0].shape(), pinned_ctx_, false, src[0].dtype());
}
}
CHECK(stype == buf.copy_buf[0].storage_type())
<< "Storage type mismatch detected. " << stype << "(src) vs. "
<< buf.copy_buf[0].storage_type() << "(buf.copy_buf)";
for (size_t i = 1; i < src.size(); ++i) {
CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); // 定义拷贝操作
reduce[i] = buf.copy_buf[i-1];
const_vars[i-1] = reduce[i].var(); // 定义拷贝操作的输入
}
Engine::Get()->PushAsync( // push该操作至engine,engine会根据输入来规划什么时候执行
[reduce, this](RunContext rctx, Engine::CallbackOnComplete on_complete) {
ReduceSumCPU(reduce);
on_complete();
}, Context::CPU(), const_vars, {reduce[0].var()},
FnProperty::kCPUPrioritized, priority, "KVStoreReduce");
}
Broadcast
CommDevice
-
Init
:将存储类型和shape信息存储在sorted_key_attrs_
中。 -
InitBuffersAndComm
:将vector<NDArray>& src
的context信息存储在devs
中,通过InitMergeBuffer
和sorted_key_attrs_
信息,将所有的KVPairs分别存储在GPU上。 -
Reduce
:和CommCPU的reduce
一样,同样也是为了累积求和。
3.1.2 kvstore.h
在kvstore.h
中定义了几个纯虚函数
-
Init
:根据参数定义的一组KVPairs初始化 -
Push
:将一组KVPairs执行Push操作 -
Pull
: Pull操作 -
Updater
:用于参数的更新
3.1.3 local 和 device
在初始化KVStore我们需要提供KVStore的类型,在MXNet中提供了local
和device
两种用于单机训练时的类型。不论哪种,都在文件kvstore_local.h
中定义。两者最主要的区别就是对Comm
的选择,local
会使用CommCPU
来进行comm_
的初始化,device
使用CommDevice
来初始化。
类KVStoreLocal
有以下几个重要的方法:
-
Init
:设置key的类型(str
或者int
),进行初始化。初始化的方法是使用comm_
的初始化方法,同时还会在local_
保存一个pinned_ctx_
类型的拷贝。pinned_ctx_
指的是不会被移出cache
的内存。 -
Push
:根据输入的KVPairs,使用comm_
的Reduce
方法,进行相同key的value的求和。并且如果注册了updater_
的话,会调用updater_
进行更新。在进行更新之前,如果是在GPU端更新,会先将保存在local_
的参数拷贝至GPU。 -
Pull
:Pull方法主要的工作是将存储在local_
的参数复制到对应的输出中。
经过对这几个主要方法的理解,我们就清楚了KVStore的主要工作方式,也就对它对内存和显存的占用有了一个清晰的了解。具体的实现细节还是要参考源码去了解。
3.1.4 KVStoreDist
这篇博客的重点还是去试图了解分布式下的KVStore,当我们使用dist-*
去create KVStore的时候,就会使用到类KVStoreDist
。KVStoreDist
分两个主要部分,一个是worker,一个是server。
如果该节点是worker,首先会创建一个ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
这个ps::KVWorker
将在pslite部分具体解析,它是主要的完成push
和pull
操作的部分。
server的启动:在我们通过import mxnet
的时候,会导入kvstore_server
,而导入该文件会允许语句_init_kvstore_server_module()
,阅读该函数源码不难发现,它会判断当前节点是否是server节点,如果是就会调用server.run()
,然后调用c++代码的MXKVStoreRunServer
,也就是类KVStoreDist
的RunServer
方法,该方法会创建server_ = new KVStoreDistServer();
-
set_updater
:updater的设置是通过python端的函数定义来完成的,它通过ctype转换成为了c端的函数,并且通过pickle序列化为字符串传递给server。
当然,我们的主要注意力还是放在push
和pull
的实现上。 -
Push_
:push操作首先会通过comm_
进行Reduce
操作,并将结果存储在comm_buf_[key]
中,完成了本地的Reduce
后,调用EncodeDefaultKey
函数将存储为key : int
和val : NDArray
形式的KVPair,转化为PSKV
形式,该形式用于Push操作。之后会通过PushDefault
方法完成操作,该方法定义了函数push_to_servers
,将comm_buf_[key]
作为输入,通过Engine::Get()->PushAsync
方法完成push操作的异步执行(只是将任务发给Engine,由Engine完成调度)。Engine会在适当的时机执行push_to_servers
,该函数调用ps_worker_
的ZPush
方法来实现分布式的push。 -
PullImpl
:pull操作由该函数来完成,该函数会根据keys
将server端的结果获取到对应的NDArray
中。中间结果会保存在comm_buf_[key]
中,这里由于之前push
将该变量作为了输入,Engine在调度执行时会考虑到这点,保证所有对comm_buf_[key]
的操作都在对它的读入完成之后,也就是push完成之后(push将它作为了输入)。类似于Push_
操作,Pull操作定义了函数pull_from_servers
作为异步执行的函数,调用PushAsync
发送给Engine。pull_from_servers
函数调用了ps_worker_
的ZPull
方法来完成分布式的pull操作。
这里的分析只是简单的流程的总结,更多实现的细节可以通过阅读源码来了解。
3.1.5 KVStoreDistServer
如果当前节点是server,那么就会建立一个KVStoreDistServer
对象,由该对象完成对worker的push,pull
请求的处理。其中最重要的方法是DataHandleEx
,它根据RequestType
来调用相应的函数完成对数据的处理。
在KVStoreDistServer
的构造函数中,会执行ps_server_ = new ps::KVServer<char>(0);
它建立了一个ps::KVServer
对象,该对象调用ps_server_->set_request_handle(std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));
将DataHandleEx
绑定在自己的request_handle_
上。
-
Run
:前面提到过,如果该节点是server会调用RunServer
方法,该方法就会调用if (server_) server_->Run();
阅读KVStoreDistServer
的源码发现,Run
仅仅只有一行exec_.Start();
。这一行会调用Executor exec_;
的Start
方法,源码如下
void Start() {
std::unique_lock<std::mutex> lk(mu_);
while (true) {
cond_.wait(lk, [this]{return !queue_.empty();}); // queue_为空,则等待被唤醒
Block blk = std::move(queue_.front()); // 取出queue头元素
queue_.pop();
lk.unlock(); // 释放锁,给其他线程操作queue
if (blk.f) { // 如果blk定义了一个function,则允许他
blk.f();
blk.p->set_value(); // 返回function的结果
} else {
blk.p->set_value(); break;
}
lk.lock(); // 获取锁,执行下一个循环
}
调用Executor
的Exec
方法,会在queue
中添加一个执行函数的block
,代码如下
void Exec(const Func& func) {
Block blk(func); // 建立block
auto fut = blk.p->get_future();
{
std::lock_guard<std::mutex> lk(mu_);
queue_.push(std::move(blk));
cond_.notify_one(); // 通知别的线程运行
}
fut.wait();
}
有了上面的知识,我们来看一下具体怎么处理数据。
-
DataHandleDefault
:该方法是默认的数据处理的方法,由于DataHandleEx
被绑定为了数据的处理函数,当RequestType
是kDefaultPushPull
,就会调用该函数。它会根据传入的信息,提取对应的key
,将对应的数据存储在store_[key]
。如果从worker来的request类型是push,就会分两种情况运行。一种是初始化的时候,由于初始化同样通过调用push来完成,因此初始化的push只会将store_[key]
设置为对应的值。另一种是初始化后,每一次的push都会进行相应的操作。这里每一次从任何一个worker来的某一个key的push操作,都会存储在updates.merged
中,并且除了第一次的push,之后的push会进行updates.merged += updates.temp_array;
也就是和之前的push相加。并且ApplyUpdates
只会在push数达到worker的个数的时候,才会真正地进行。也只有在ApplyUpdates
真正执行的时候才会将回复返回给worker。这样,就实现了同步。
对于server的讲解,这里也只是简单地描述它的同步和执行的简单机制,具体更多的实现细节,可以参考源码来了解。
3.2 pslite
通过前面的了解,我们知道了worker会使用ps_worker_
的ZPush
方法来完成push操作,使用ZPull
方法来完成pull操作。类似地,server会使用ps_server_
的request_handle_
来进行数据处理的传递,使用SimpleApp
的request_handle_
来完成Command
处理的传递。这一部分,我们就来了解它们的实现。
KVWorker
和KVServer
都定义在文件kv_app.h
中,它们都继承自SimpleApp
。
3.2.1 kv_app
kv_app
是MxNet主要应用的部分。ps-lite
实现了两个app
,一个是simple_app
,一个是kv_app
。
KVWorker
当数据从MXNet端的push函数传递到parameter server端时,调用了如下方法:
int ZPush(const SArray<Key>& keys,
const SArray<Val>& vals,
const SArray<int>& lens = {},
int cmd = 0,
const Callback& cb = nullptr) {
int ts = obj_->NewRequest(kServerGroup);
AddCallback(ts, cb);
KVPairs<Val> kvs;
kvs.keys = keys;
kvs.vals = vals;
kvs.lens = lens;
Send(ts, true, cmd, kvs);
return ts;
}
该方法将kServerGroup
作为数据传输对象,建立了KVPairs,通过Send
方法,将数据发送给server。Send
方法完成了数据从KVParis到Message的转换,然后调用Postoffice::Get()->van()->Send(msg);来执行数据的发送。
相应地,在执行pull操作的时候,调用了Pull_
方法,该方法首先定义了一个回调函数,该函数在完成pull操作后执行,具体来说就是当发出的请求都得到了回应后,会在Process
方法中执行下列函数:
// finished, run callbacks
if (obj_->NumResponse(ts) == Postoffice::Get()->num_servers() - 1) {
RunCallback(ts);
}
KVServer
前面说到过,KVServer会使用request_handle_来调用KVStore的数据处理函数。KVServer会在方法KVServer<Val>::Process
中调用request_handle_
,在这之前它会将得到的Message
转换为KVMeta
和KVPairs
。这样就完成了数据从接收到,再到传递给MXNet端的数据处理函数的过程。
由于时间有限,内容较多,就不一一介绍函数。
3.2.2 postoffice.cc
Postoffice
是一个类似于全局管理者的角色,它完成了环境初始化等必要工作
3.2.3 van
从前面的介绍我们看到,所有的数据在发送的最后,调用的都是van的send方法。van的具体实现类是ZMQVan
。由于本人对于zmq
也只是个初学者,这里有兴趣的同学可以去详细了解它的实现以及性能。
3.2.4 meta.proto
zmq在进行数据传输的时候,会建立socket,并且将字符串传递给对应的对象。在代码中,使用了protobuf来进行数据到字符串的转换工作。
3.2.4 SArray.h
SArray全名Shared array,它完成了在进行数据赋值过程中的零拷贝,及时是不同类型间数据的赋值,仅仅是将数据指向的指针进行赋值,同时将类型进行保存而已。
3.2.5 message.h
后记
今天已经很晚了,只能在pslite部分草草收尾,希望下次进行总结的时候能够做的更好。
总体来说,MXNet对于我这样一个初学者来说有很多可以学习的地方,并且它异步的实现和parameter server的设计,都是非常值得学习的内容。