介绍
在Caffe2的设计中,一切操作皆是Op。其中数据IO操作相关的ops有CreateDBOp、PrefetchOperator、ImageInputOp等;用于初始化数据的操作ops有UniformFillOp、GaussianFillOp等;
还有真正用于计算的Ops像FullyConnectedOp、FullyConnectedGradientOp、ConvOp、ConvGradientOp等;用于多节点训练时梯度融合的操作如BroadcastOp、AllreduceOp、AllgatherOp等;
其它则还有些用于各种参数更新算法的Ops如AdamOp、AdagradOp、MomentumSGDOp等。
一般所有的Caffe2 Python前端代码写毕后,两张建好的static Graph就在背后生成出来了。一张Graph为init_net,上面包含了一些初始化操作像Parameters的filler操作等或者多节点通信时
整体通信环境构建时的操作像CreateCommonWorld、CloneCommonWorld等。通常init_net上的ops都只需执行一次完成初始化目的即可。另一张Graph则为net,它上面包含了我们正式模型的基本所
用ops像一些用于计算的Ops如Conv/FC等,还有些则是用于更新参数的Ops像AdamOp/MomentumSGDOp等。它在init_net执行完后再执行,需要执行多次,最终得到训练好的模型参数。
可以说Operator是Caffe2中最核心的元素之一,往往也是我们使用一个framework写AI程序所需接触最频繁的一个部件。
Caffe2中核心Operator的实现在两个class里面:OperatorBase与Operator,其中Operator是OperatorBase的一个子类。大多数我们用到的Op都是Operator的子类,只需要实现它的若干override
函数即可,但也有些特别情况像PrefetchOp直接从OperatorBase继承而来,因为我们想直接掌握它之上Op执行时的输出同步情况。
以下为caffe2核心代码中与Operator相关的一些代码文件。
core git:(master) ✗ ls operator
operator_c10wrapper.cc operator.cc operator_gradient.h operator_schema.cc operator_schema_test.cc
operator_c10wrapper.h operator_gpu_test.cc operator.h operator_schema.h operator_test.cc
我们将主要介绍一下OperatorBase与Operator这两个类的基本APIs及其实现,同时稍带着也会看些其它像schema、gradient等之类的operator features。
OperatorBase
Observer和ObservableBase类
首先如下所示,OperatorBase是Observable<OperatorBase>(本质上是一个ObserverBase的一个实例类)的一个子类。这里使用了Observer的设计模式。即它考虑将每个Operator都设计为可观察的对象,以便在它执行操作时有相关的Observer
可以去侦测它的执行状态。
class CAFFE2_API OperatorBase;
typedef ObserverBase<OperatorBase> OperatorObserver;
class CAFFE2_API OperatorBase : public Observable<OperatorBase> {
public:
explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
virtual ~OperatorBase() noexcept {}
如下为ObservableBase与Observer的定义,还是蛮简单的,我们可据此完成我们自己对某一类型T的observer。在Caffe2中真正使用Observer来去监控的对象有两个,一个为这里介绍的Operator,另一个则为我们以后将会去分析的network。
template <class T>
class ObserverBase {
public:
explicit ObserverBase(T* subject) : subject_(subject) {}
virtual void Start() {}
virtual void Stop() {}
virtual std::string debugInfo() {
return "Not implemented.";
}
virtual ~ObserverBase() noexcept {};
T* subject() const {
return subject_;
}
virtual std::unique_ptr<ObserverBase<T>> rnnCopy(T* subject, int rnn_order)
const {
return nullptr;
};
protected:
T* subject_;
};
/**
* Inherit to make your class observable.
*/
template <class T>
class Observable {
public:
Observable() = default;
Observable(Observable&&) = default;
Observable& operator =(Observable&&) = default;
virtual ~Observable() = default;
C10_DISABLE_COPY_AND_ASSIGN(Observable);
using Observer = ObserverBase<T>;
/* Returns a reference to the observer after addition. */
const Observer* AttachObserver(std::unique_ptr<Observer> observer) {
CAFFE_ENFORCE(observer, "Couldn't attach a null observer.");
std::unordered_set<const Observer*> observers;
for (auto& ob : observers_list_) {
observers.insert(ob.get());
}
const auto* observer_ptr = observer.get();
if (observers.count(observer_ptr)) {
return observer_ptr;
}
observers_list_.push_back(std::move(observer));
UpdateCache();
return observer_ptr;
}
/**
* Returns a unique_ptr to the removed observer. If not found, return a
* nullptr
*/
std::unique_ptr<Observer> DetachObserver(const Observer* observer_ptr) {
for (auto it = observers_list_.begin(); it != observers_list_.end(); ++it) {
if (it->get() == observer_ptr) {
auto res = std::move(*it);
observers_list_.erase(it);
UpdateCache();
return res;
}
}
return nullptr;
}
private:
inline static void StartObserver(Observer* observer) {
try {
observer->Start();
} catch (const std::exception& e) {
LOG(ERROR) << "Exception from observer: " << e.what();
} catch (...) {
LOG(ERROR) << "Exception from observer: unknown";
}
}
inline static void StopObserver(Observer* observer) {
try {
observer->Stop();
} catch (const std::exception& e) {
LOG(ERROR) << "Exception from observer: " << e.what();
} catch (...) {
LOG(ERROR) << "Exception from observer: unknown";
}
}
.............
.............
private:
// an on-stack cache for fast iteration;
// ideally, inside StartAllObservers and StopAllObservers,
// we should never access observers_list_
Observer* observer_cache_;
size_t num_observers_ = 0;
protected:
std::vector<std::unique_ptr<Observer>> observers_list_;
};
Input/Output blobs及Arguments基本操作
OperatorBase里面包含了些最基本的Operator所需的操作像输入、输出Blob处理(获得或其之上类型查询、Copy等),Operator参数处理等。
以下为它用于参数处理的两个示例API函数。
/** @brief Checks if the operator has an argument of the given name.
*/
inline bool HasArgument(const string& name) const {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::HasArgument(*operator_def_, name);
}
// Functions that deal with arguments. Basically, this allows us to map an
// argument name to a specific type of argument that we are trying to access.
template <typename T>
inline T GetSingleArgument(const string& name, const T& default_value) const {
CAFFE_ENFORCE(operator_def_, "operator_def was null!");
return ArgumentHelper::GetSingleArgument<OperatorDef, T>(
*operator_def_, name, default_value);
}
下面则为它处理Input/Output blob的一些API函数。其中inputs_/outputs_分别是一个operator上所具有的blobs成员属性。有了之前几章讲过的Blob与Tensor的知识后,这些函数也就比较好懂了。
// Get the inputs and outputs as specific types.
template <typename T>
inline const T& Input(int idx) {
static_assert(
!std::is_same<T, Tensor>::value,
"You should use Input<Tensor>(int, DeviceType) for "
"Tensor.");
DCHECK_LT(idx, inputs_.size());
try {
return inputs_.at(idx)->template Get<T>();
} catch (::caffe2::EnforceNotMet& enf) {
if (has_debug_def()) {
enf.AppendMessage(".\nOffending Blob name: ");
enf.AppendMessage(debug_def().input(idx));
enf.AppendMessage(".\n");
}
throw enf;
}
}
// TODO(jerryzh): Remove template
// and the type argument?
// This is to keep the API changes minimal and make refactoring
// a bit easier
template <typename T>
inline const T& Input(int idx, DeviceType type) {
static_assert(
std::is_same<T, Tensor>::value,
"Input(int, DeviceType) is only available for Tensor");
DCHECK_LT(idx, inputs_.size());
try {
// TODO(jerryzh): We'll need to check device type in Get<T>() later
// Get<T>() -> Get<T>(type)
const auto& tensor = inputs_.at(idx)->template Get<T>();
return tensor;
} catch (::caffe2::EnforceNotMet& enf) {
if (has_debug_def()) {
enf.AppendMessage(".\nOffending Blob name: ");
enf.AppendMessage(debug_def().input(idx));
enf.AppendMessage(".\n");
}
throw enf;
}
}
template <typename T>
inline T* Output(int idx, DeviceType type) {
static_assert(
std::is_same<T, Tensor>::value,
"Output(int, DeviceType) is only available for Tensor");
// When you get a Tensor here it is not fully initialized
return BlobGetMutableTensor(outputs_.at(idx), type);
}
inline Tensor*
OutputTensor(int idx, at::IntList dims, at::TensorOptions options) {
CAFFE_ENFORCE_WITH_CALLER(
options.device_opt() != c10::nullopt,
"device must be provided in option.");
return BlobGetMutableTensor(outputs_.at(idx), dims, options);
}
template <typename T>
inline T* Output(int idx, T* allocated) {
outputs_.at(idx)->Reset(allocated);
return allocated;
}
其它与输入、输出相关的API函数还有如下一些utilities函数。(当然这里只是一部分,但大致皆是如此,多是些浅显易懂的,多是像OutputIsType这样直接使用了Blob的成员函数)。
template <typename T>
inline bool OutputIsType(int idx) {
static_assert(
!std::is_same<T, Tensor>::value,
"You should use OutputIsTensorType(int, DeviceType) for "
"Tensor.");
return outputs_.at(idx)->template IsType<T>();
}
inline bool OutputIsTensorType(int idx, DeviceType type) {
return BlobIsTensorType(*outputs_.at(idx), type);
}
inline int InputSize() const {
return inputs_.size();
}
inline int OutputSize() const {
return outputs_.size();
Event处理相关async操作函数
下面一些与Event相关的操作才是真正开始有趣的东西。每个Caffe2 op都有一个Event成员(Event在将来我们也会单独拿出来介绍),用于同步op间的依赖执行操作。它是一个op async执行时保证parent op与child op之间同步的产物。如果读者
熟悉CUDA编程APIs,那么对于cudaEvent一定不陌生,本质上在CUDAOperator当中,它之上的event_就是一个cudaEvent的wrapper。
简单说来,async执行时,我们有些op可能需要去check(Query或Wait)父类的event以确定它是否执行完了(即自己的输入有保证了)。同时它们也要在执行过后(或者是完成了自己op执行命令的提交),更新自己的event,以对自己的childeren operators发挥影响。
virtual void WaitEvent(const Event& ev, int /*stream_id */ = -1) {
ev.Finish();
}
inline void Wait(const OperatorBase& other, int stream_id = -1) {
if (!other.IsEventDisabled()) {
WaitEvent(other.event(), stream_id);
}
}
virtual void WaitEvents(
const std::vector<const Event*>& events,
int /*stream_id*/ = -1) {
for (const auto& ev : events) {
ev->Finish();
}
}
virtual void Finish() {
if (event_) {
event_->Finish();
}
}
const Event& event() const {
CAFFE_ENFORCE(event_, "Event is disabled");
return *event_;
}
Event& event() {
CAFFE_ENFORCE(event_, "Event is disabled");
return *event_;
}
void ResetEvent() {
if (event_) {
event_->Reset();
}
}
void DisableEvent() {
event_ = nullptr;
}
bool IsEventDisabled() const {
return !event_;
}
protected:
virtual void RecordEvent(const char* /*err_msg*/ = nullptr) {
CAFFE_NOT_IMPLEMENTED;
}
void SetEventFinished(const char* err_msg = nullptr) {
if (event_) {
event_->SetFinished(err_msg);
}
}
void SetEventFinishedWithException(const char* err_msg = nullptr) {
if (event_) {
event_->SetFinishedWithException(err_msg);
}
}
下面的两个APIs函数同样与op的async执行相关。在OperatorBase中将它一律简单设为了false即op默认不support这两种feature(至于feature具体为何义,我们将在Operator讲解时进行阐述)。
可见大部分与async执行相关的Op操作都被放入了Operator中,OperatorBase这里有API人,但其实形同虚设。所以如果你在实现自己的Op时,如果想遵循Caffe2中已有一套Op async执行逻辑,那么可将自己op实现为
operator的子类(这是绝大多数时候的正常做法,毕竟caffe2是assume用户会直接遵循自己关于event/op async执行的一套做法的即operator中的那些默认做法);但如果你对某一Device context上的Op async执行
有着与framework original design略为不同的想法时,那么你就可以将op实现为OperatorBase的子类,然后去重载这几个与async执行、Events操作相关的APIs函数。
virtual bool HasAsyncPart() const {
return false;
}
virtual bool SupportsAsyncScheduling() const {
return false;
}
Utility函数
下面是一些用于出错时debug的APIs函数,无关核心。
inline const OperatorDef& debug_def() const {
CAFFE_ENFORCE(has_debug_def(), "operator_def was null!");
return *operator_def_;
}
inline void set_debug_def(
const std::shared_ptr<const OperatorDef>& operator_def) {
operator_def_ = operator_def;
}
inline bool has_debug_def() const {
return operator_def_ != nullptr;
}
public:
void RecordLastFailedOpNetPosition() {
if (net_position_ != kNoNetPositionSet) {
VLOG(1) << "Operator with id " << net_position_ << " failed";
operator_ws_->last_failed_op_net_position = net_position_;
} else {
VLOG(1) << "Failed operator doesn't have id set";
}
}
int net_position() const {
return net_position_;
}
void set_net_position(int idx) {
net_position_ = idx;
}
Caffe2中每个Operator都有Device Context的概念即此Op操作是在哪个device之上进行的。Device的具体信息则可从DeviceOption中获得。
const DeviceOption& device_option() const {
return device_option_;
}
另外每个Operator又可根据其Engine不同有着不一样的实现(在同一Device context下面)。这个设计follow了之前Caffe中的思路(如Layer的类型可以为GPU engine或CPU engine等)。
要知道像其它Caffe2组件如Blob/Tensor/Context等使用注册机制来完成具体定义一样,Operator也是在完成后需要去注册在某一device相关的OperatorRegistry下面的。一般Operator name会
作为此operator实现item的key。而若我们在同一device下面对同一类型op(有同样的名字op_name)有着不同的实现(比如Allreduce操作在CPU设计上可以使用MPI,也可使用Gloo来完成其底层的具体通信aggregation),
那我们就可使用engine来区别这一不同的实现。
下面是与engine相关的一些base operator函数。
void annotate_engine(const std::string& engine) {
engine_ = engine;
}
const std::string& engine() const {
return engine_;
}
像其它许多深度学习Framework一样,Caffe2对于Op操作执行在GPU device上有着很强的假设(very bad assumptions,它让CUDA生态在DL领域过于强大了)。如下函数即为此类假设的一个缩影。
// Checks whether stream is ready to execute new computation,
// used in stream allocation optimization to skip stream that is currently
// busy. Depends on context and operator's device, returns true by default
virtual bool IsStreamFree(int /* unused */) const {
return true;
}
核心的Run及RunAsync函数
最终我们看下关键的Op真正运行的两个函数吧。如下示,其中Run一般为op sync执行时所调用的,而RunAsync则是其async执行时所调用的。可以看出在执行时都默认我们会将操作提交到一stream上,这显然亦是
由CUDA编程API中所用到的cuda_stream的思想而来。。
virtual bool Run(int /* unused */ /*stream_id*/ = 0) {
CAFFE_NOT_IMPLEMENTED;
}
// RunAsync, if implemenented by the specific operators, will schedule the
// computation on the corresponding context and record the event in its
// event_ member object. If the specific operator does not support RunAsync,
// it will simply be synchronous as a fallback.
virtual bool RunAsync(int stream_id = 0) {
try {
auto result = Run(stream_id);
if (result) {
if (HasAsyncPart()) {
RecordEvent();
} else {
SetEventFinished();
}
} else {
SetEventFinished(getErrorMsg().c_str());
}
return result;
} catch (EnforceNotMet& err) {
SetEventFinishedWithException(err.what());
throw;
} catch (const std::exception& err) {
SetEventFinishedWithException(err.what());
throw;
} catch (...) {
SetEventFinishedWithException(getErrorMsg().c_str());
throw;
}
}