Introduction
The dispatcher is an internal component of PyTorch that is responsible for figuring out which code should actually run when you call a function such as torch::add. It is needed because PyTorch operators have to handle many cross-cutting concerns that are layered on top of one another, for example:
- switching between the CPU and CUDA implementations of an operator, depending on the device of the input Tensors
- switching between the autograd and backend implementations of an operator, depending on whether autograd handling is required
- whether autocast needs to be applied for mixed precision
- whether the batching rules should be applied for an operator running under a vmap call
- whether the execution of the operator should be traced
PyTorch represents these different concerns with DispatchKey:
enum class DispatchKey : uint8_t {
  Undefined = 0,
  CatchAll = Undefined,
  CPU,  // registered at build/aten/src/ATen/RegisterCPU.cpp
  CUDA, // registered at build/aten/src/ATen/RegisterCUDA.cpp
  HIP,  // NB: I think this is not actually used, due to Note [Masquerading as
        // CUDA]
  FPGA, // Xilinx support lives out of tree at
  ORT,
  XLA,  // lives out of tree at https://github.com/pytorch/xla
  MLC,  // lives out of tree at https://github.com/pytorch/MLCompute
  Vulkan,
  Metal,
  XPU,  // For out of tree Intel's heterogeneous computing plug-in
  HPU,  // For out of tree & closed source integration of HPU / Habana
  VE,   // For out of tree & closed source integration of SX-Aurora / NEC
  Lazy, // For lazy tensor backends
  SWAI, // For out of tree SWAI backend
  Meta,
  QuantizedCPU,  // registered at build/aten/src/ATen/RegisterQuantizedCPU.cpp
  QuantizedCUDA, // registered at build/aten/src/ATen/RegisterQuantizedCUDA.cpp
  QuantizedXPU,  // For out of tree Intel's heterogeneous computing plug-in
  CustomRNGKeyId,
  ...
};
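For intuition, here is a minimal sketch (assuming a standard libtorch build; not taken from this article's source) showing that every at::Tensor carries a c10::DispatchKeySet describing which of these concerns apply to it; that key set is what the dispatcher later uses to pick a kernel.
// Minimal sketch: inspect the DispatchKeySet carried by a tensor.
#include <torch/torch.h>
#include <iostream>

int main() {
  at::Tensor dense = torch::randn({2, 2});                  // ordinary dense CPU tensor
  at::Tensor quant = at::quantize_per_tensor(
      dense, /*scale=*/0.1, /*zero_point=*/0, at::kQUInt8); // quantized CPU tensor

  // The printed sets differ: the dense tensor carries DispatchKey::CPU,
  // the quantized one carries DispatchKey::QuantizedCPU (among other keys).
  std::cout << dense.key_set() << "\n";
  std::cout << quant.key_set() << "\n";
  return 0;
}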
In short, the dispatcher solves one problem: which kernel should be called.
Naively, the different cases could be handled with an if statement, as in the following example:
class MyAddFunction : ... {
 public:
  static Tensor forward(
      AutogradContext *ctx, torch::Tensor self, torch::Tensor other) {
    if (self.device().type() == DeviceType::CPU) {
      return add_cpu(self, other);
    } else if (self.device().type() == DeviceType::CUDA) {
      return add_cuda(self, other);
    } else {
      TORCH_CHECK(0, "Unsupported device ", self.device().type());
    }
  }
  ...
}
So why use the dispatcher instead of an if statement?
- It is decentralized: there is no need to write a separate, centralized if statement for every new operator. Moreover, when a third party adds an implementation of an operator for a new case (e.g. a new device), they do not have to patch the operator's original implementation.
- Dispatch keys support more concerns than just CPU, CUDA and Autograd; c10/core/DispatchKey.h already defines a whole series of dispatch keys.
- It supports boxed fallback functions, which are written once and then apply to all operators (see the sketch right after this list).
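A hedged illustration of the last point: a boxed fallback is registered for an entire dispatch key rather than for a single operator. The key used below, AutogradOther, is only an example of the pattern PyTorch itself uses.
// Sketch: one fallback, applied to every operator dispatched to this key.
// makeFallthrough() simply skips this key and keeps dispatching; a real
// fallback could instead use torch::CppFunction::makeFromBoxedFunction<&fn>()
// with fn(const c10::OperatorHandle&, torch::jit::Stack*).
#include <torch/library.h>

TORCH_LIBRARY_IMPL(_, AutogradOther, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}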
The dispatch mechanism
Concepts
First, some definitions:
operator: an operator, such as add
kernels: kernel functions, i.e. the different implementations of an operator for different devices (CPU, CUDA), different input layouts (dense, sparse), and with or without autograd
The idea
The dispatch mechanism replaces the chain of if statements with table lookups: the Dispatcher controls the dispatching of every operator, and under the hood the mappings are lookup tables (a hash map for operator names, an array indexed by dispatch key).
The first level of dispatch maps an operator name to an OperatorHandle (every operator is managed by its own OperatorHandle).
The second level of dispatch maps a dispatch key to a kernel function (each device, input layout, etc. corresponds to a dispatch key). A rough sketch of these two levels follows below.
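The following is not the real data structure layout, just a simplified sketch of the two lookup levels; all names here (FakeDispatcher, FakeOperatorEntry, Kernel) are made up for illustration.
// Conceptual sketch of the two dispatch levels, with simplified types.
#include <array>
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

using Kernel = std::function<void()>;     // stand-in for KernelFunction
constexpr size_t kNumDispatchKeys = 64;   // stand-in for DispatchKey::NumDispatchKeys

struct FakeOperatorEntry {
  // level 2: dispatch key -> kernel (an array, since keys are small integers)
  std::array<Kernel, kNumDispatchKeys> dispatchTable_;
};

struct FakeDispatcher {
  // level 1: operator name -> per-operator entry (a hash map)
  std::unordered_map<std::string, FakeOperatorEntry> operatorLookupTable_;

  Kernel& lookup(const std::string& op, uint8_t dispatch_key) {
    return operatorLookupTable_.at(op).dispatchTable_.at(dispatch_key);
  }
};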
Code walkthrough
The main dispatcher code lives under aten/src/ATen/core/dispatch.
The main classes are:
● Dispatcher: maps an operator name to an OperatorHandle
● OperatorHandle: registers, looks up, and calls the kernels of a specific operator
● OperatorEntry: maps a dispatch key to a KernelFunction
● KernelFunction: wraps a backend kernel
The Dispatcher class
The operatorLookupTable_ member stores the mapping from operator name to OperatorHandle:
class TORCH_API Dispatcher final {
  std::list<OperatorDef> operators_;
  LeftRight<ska::flat_hash_map<OperatorName, OperatorHandle>> operatorLookupTable_;
};
The following methods all look up an OperatorHandle from an operator name; see the source for the concrete implementations (a usage sketch follows the declarations):
c10::optional<OperatorHandle> findSchema(const OperatorName& operator_name);
OperatorHandle findSchemaOrThrow(const char* name, const char* overload_name);
c10::optional<OperatorHandle> findOp(const OperatorName& operator_name);
const std::vector<OperatorName> getAllOpNames();
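For example, following the pattern from the PyTorch "extending the dispatcher" tutorial, an operator can be invoked through the Dispatcher by name. This sketch assumes the myops::myadd schema registered later in this article.
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

at::Tensor myadd(const at::Tensor& self, const at::Tensor& other) {
  // Level 1: operator name -> (Typed)OperatorHandle
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("myops::myadd", "")  // operator name + overload name
      .typed<at::Tensor(const at::Tensor&, const at::Tensor&)>();
  // Level 2 happens inside call(): dispatch key set -> kernel function
  return op.call(self, other);
}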
The OperatorHandle class
An OperatorHandle points to an OperatorDef, which contains an OperatorEntry; the actual dispatch-key-to-kernel mapping is handled by OperatorEntry:
class TORCH_API OperatorHandle {
  Dispatcher::OperatorDef* operatorDef_;

  void callBoxed(Stack* stack) const {
    c10::Dispatcher::singleton().callBoxed(*this, stack);
  }

  void callBoxed(Stack& stack) const {
    callBoxed(&stack);
  }

  void redispatchBoxed(DispatchKeySet ks, Stack* stack) const {
    c10::Dispatcher::singleton().redispatchBoxed(*this, ks, stack);
  }
};
The OperatorEntry class
Like the Dispatcher, OperatorEntry also holds a lookup table, which stores the mapping from dispatch key to KernelFunction. Since a dispatch key is a uint8_t enum value, this table is implemented as a plain array:
class TORCH_API OperatorEntry final {
  std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
};
Looking up a dispatch key returns the corresponding kernel function:
const KernelFunction& lookup(DispatchKey k) const {
  const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
  // A valid kernel *always* has a boxed kernel and *may* have an
  // unboxed kernel. However, we typically do unboxed calls in at::
  // APIs, where the kernel 1) will very likely be valid and 2)
  // should have an unboxed kernel. Checking the unboxed kernel
  // first will allow us to avoid touching the boxed kernel at all
  // in the common case.
  if (C10_UNLIKELY(!kernel.isValidUnboxed())) {
    if (!kernel.isValid()) {
      reportError(k);
    }
  }
  return kernel;
}
The KernelFunction class
KernelFunction wraps the backend kernel in both boxed and unboxed form (boxed_kernel_func_ and unboxed_kernel_func_), and functor_ points to the backend kernel functor:
class TORCH_API KernelFunction final {
  OperatorKernel* getFunctor_() const;

  std::shared_ptr<OperatorKernel> functor_;
  InternalBoxedKernelFunction* boxed_kernel_func_;
  void* unboxed_kernel_func_;
};
KernelFunction::call ultimately invokes the concrete kernel behind functor_:
template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
  // note: Args above is intentionally not Args&&. We don't want perfect
  // forwarding, which would require Args to be deduced, but instead we
  // want callers to explicitly specify the Args.
  if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
    return callUnboxedKernelFunction<Return, Args...>(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward<Args>(args)...);
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      boxed_kernel_func_ != nullptr,
      "Tried to call KernelFunction::call() on an uninitialized KernelFunction."
  );

  return impl::BoxedKernelWrapper<Return(Args...)>::call(
      boxed_kernel_func_,
      functor_.get(),
      opHandle,
      dispatchKeySet,
      std::forward<Args>(args)...
  );
}
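For a sense of where these members come from, here is a hedged sketch of wrapping a plain (hypothetical) C++ function in a KernelFunction; normally this happens for you inside m.impl():
#include <ATen/ATen.h>
#include <ATen/core/boxing/KernelFunction.h>

// Hypothetical unboxed kernel.
at::Tensor my_relu_cpu(const at::Tensor& self) {
  return at::clamp_min(self, 0);
}

// Wraps the unboxed function pointer; the resulting KernelFunction also gets a
// synthesized boxed entry point, so the same kernel can be reached from boxed
// call sites (callBoxed) as well as from the fast unboxed path shown above.
static auto kf = c10::KernelFunction::makeFromUnboxedRuntimeFunction(&my_relu_cpu);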
The dispatcher registration mechanism
Registering an operator
An operator is registered like this:
TORCH_LIBRARY(myops, m) {
  m.def("myadd(Tensor self, Tensor other) -> Tensor");
}
The macro is defined in torch/library.h; tracing through the code, the concrete implementation is Library::_def in aten/src/ATen/core/library.cpp:
#define DEF_PRELUDE "def(\"", schema.operator_name(), "\"): "
Library& Library::_def(c10::FunctionSchema&& schema, c10::OperatorName* out_name) & {
TORCH_CHECK(kind_ == DEF || kind_ == FRAGMENT,
DEF_PRELUDE,
"Cannot define an operator inside of a ", toString(kind_), " block. "
"All def()s should be placed in the (unique) TORCH_LIBRARY block for their namespace. ",
ERROR_CONTEXT
);
TORCH_INTERNAL_ASSERT(ns_.has_value(), ERROR_CONTEXT);
TORCH_INTERNAL_ASSERT(!dispatch_key_.has_value(), ERROR_CONTEXT);
auto ns_opt = schema.getNamespace();
if (ns_opt.has_value()) {
// Note [Redundancy in registration code is OK]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In an earlier version of this code, I made it an error to explicitly
// specify the namespace, even when the namespaces match. I've decided
// to relax this constraint because sometimes we code generate registrations
// and you cannot conveniently tell what the enclosing context will be;
// in these cases, it is simpler (and less error prone) to place all
// of the information in the registration site, which will be cross-checked
// in the end in any case (and if it turns out you DON'T have the right
// information at the site, as is the case with backend specific
// per-op registrations, you will get the right behavior!)
TORCH_CHECK(*ns_opt == *ns_,
"Explicitly provided namespace (", *ns_opt, ") in schema string "
"does not match namespace of enclosing ", toString(kind_), " block (", *ns_, "). "
"Move this definition to the (unique) TORCH_LIBRARY block corresponding to this namespace "
"(and consider deleting the namespace from your schema string.) ",
ERROR_CONTEXT
);
} else {
bool b = schema.setNamespaceIfNotSet(ns_->c_str());
TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
}
if (out_name) {
*out_name = schema.operator_name(); // copy!
}
registrars_.emplace_back(
c10::Dispatcher::singleton().registerDef(
std::move(schema),
debugString("", file_, line_)
)
);
return *this;
}
In the end, Library::_def calls Dispatcher::registerDef, which records the mapping from operator name to operator, i.e. inserts an entry into the operator lookup table:
RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug) {
  // we need a lock to avoid concurrent writes
  std::lock_guard<std::mutex> lock(mutex_);

  OperatorName op_name = schema.operator_name();
  auto op = findOrRegisterName_(op_name);

  TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
    " Each overload's schema should only be registered with a single call to def().",
    " Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());

  op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug));
  listeners_->callOnOperatorRegistered(op);

  // NB: do not increment the counts until AFTER error checking
  ++op.operatorDef_->def_count;
  ++op.operatorDef_->def_and_impl_count;

  return RegistrationHandleRAII([this, op, op_name] {
    deregisterDef_(op, op_name);
  });
}
Registering a kernel
The following code registers the CPU kernel of the myadd operator with the dispatcher:
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);
}
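For completeness, here is a sketch of what myadd_cpu itself might look like, adapted from the dispatcher-extension tutorial; the float-only loop is purely illustrative, only the signature has to match the registered schema.
#include <ATen/ATen.h>

at::Tensor myadd_cpu(const at::Tensor& self_, const at::Tensor& other_) {
  TORCH_CHECK(self_.sizes() == other_.sizes());
  TORCH_INTERNAL_ASSERT(self_.device().type() == c10::DeviceType::CPU);
  TORCH_INTERNAL_ASSERT(other_.device().type() == c10::DeviceType::CPU);
  at::Tensor self = self_.contiguous();
  at::Tensor other = other_.contiguous();
  at::Tensor result = at::empty(self.sizes(), self.options());
  const float* self_ptr = self.data_ptr<float>();
  const float* other_ptr = other.data_ptr<float>();
  float* result_ptr = result.data_ptr<float>();
  for (int64_t i = 0; i < result.numel(); i++) {
    result_ptr[i] = self_ptr[i] + other_ptr[i];  // the actual "add" work
  }
  return result;
}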
As with operator registration, tracing the kernel registration macro leads to Library::_impl:
#define IMPL_PRELUDE "impl(\"", name_str, "\", ...): "
Library& Library::_impl(const char* name_str, CppFunction&& f) & {
auto name = torch::jit::parseName(name_str);
auto ns_opt = name.getNamespace();
// This is kind of similar to the checking in def(), but the error
// messages are a little different for this call site
if (ns_opt.has_value()) {
// See Note [Redundancy in registration code is OK]
TORCH_CHECK(*ns_opt == *ns_,
IMPL_PRELUDE,
"Explicitly provided namespace (", *ns_opt, ") in operator name "
"does not match namespace of enclosing ", toString(kind_), " block (", *ns_, "). "
"Move this definition to the ", toString(kind_), " block corresponding to this namespace "
"(and consider deleting the namespace from your schema string.) ",
ERROR_CONTEXT
);
} else {
bool b = name.setNamespaceIfNotSet(ns_->c_str());
TORCH_INTERNAL_ASSERT(b, ERROR_CONTEXT);
}
// See Note [Redundancy in registration code is OK]
TORCH_CHECK(!(f.dispatch_key_.has_value() &&
dispatch_key_.has_value() &&
*f.dispatch_key_ != *dispatch_key_),
IMPL_PRELUDE,
"Explicitly provided dispatch key (", *f.dispatch_key_, ") is inconsistent "
"with the dispatch key of the enclosing ", toString(kind_), " block (", *dispatch_key_, "). "
"Please declare a separate ", toString(kind_), " block for this dispatch key and "
"move your impl() there. "
ERROR_CONTEXT
);
auto dispatch_key = f.dispatch_key_.has_value() ? f.dispatch_key_ : dispatch_key_;
registrars_.emplace_back(
c10::Dispatcher::singleton().registerImpl(
std::move(name),
dispatch_key,
std::move(f.func_),
// NOLINTNEXTLINE(performance-move-const-arg)
std::move(f.cpp_signature_),
std::move(f.schema_),
debugString(std::move(f.debug_), file_, line_)
)
);
return *this;
}
Library::_impl then calls Dispatcher::registerImpl:
RegistrationHandleRAII Dispatcher::registerImpl(
  OperatorName op_name,
  c10::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  c10::optional<impl::CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  std::lock_guard<std::mutex> lock(mutex_);

  auto op = findOrRegisterName_(op_name);

  auto handle = op.operatorDef_->op.registerKernel(
    *this,
    dispatch_key,
    std::move(kernel),
    // NOLINTNEXTLINE(performance-move-const-arg)
    std::move(cpp_signature),
    std::move(inferred_function_schema),
    std::move(debug)
  );

  ++op.operatorDef_->def_and_impl_count;

  return RegistrationHandleRAII([this, op, op_name, dispatch_key, handle] {
    deregisterImpl_(op, op_name, dispatch_key, handle);
  });
}
registerImpl finds the corresponding OperatorHandle and calls OperatorEntry::registerKernel.
The key part of registerKernel is the update of the dispatch table at the end of the function (the updateDispatchTable_ / updateDispatchTableFull_ calls):
OperatorEntry::AnnotatedKernelContainerIterator OperatorEntry::registerKernel(
const c10::Dispatcher& dispatcher,
c10::optional<DispatchKey> dispatch_key,
KernelFunction kernel,
c10::optional<CppSignature> cpp_signature,
std::unique_ptr<FunctionSchema> inferred_function_schema,
std::string debug
) {
// NB: cpp_signature doesn't get cleared even after the kernel that populated
// it is deleted. This means you could poison the value of cpp_signature_
// with a bad signature value, and then it would permanently stay there until
// you deregister the schema. This can't really be fixed, because we
// only do a typed() test once in the lifetime of a TypedOperatorHandle,
// which means if you could validly change the type of a cpp_signature, then
// that would also invalidate the old TypedOperatorHandles.
if (cpp_signature.has_value()) {
if (cpp_signature_.has_value()) {
TORCH_CHECK(*cpp_signature == cpp_signature_->signature,
"\nMismatch in kernel C++ signatures\n",
" operator: ", (this->schema_.has_value() ? toString(this->schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
" kernel 1: ", cpp_signature_->signature.name(), "\n",
" dispatch key: ", toString(cpp_signature_->dispatch_key), "\n",
" ", cpp_signature_->debug, "\n",
" kernel 2: ", cpp_signature->name(), "\n",
" dispatch key: ", toString(dispatch_key), "\n",
" ", debug, "\n"
);
} else {
cpp_signature_ = CppSignatureWithDebug { *cpp_signature, debug, dispatch_key };
}
}
if (schema_ && inferred_function_schema) {
checkSchema(name_, schema_->schema, schema_->debug, *inferred_function_schema, debug);
}
// Add the kernel to the kernels list,
// possibly creating the list if this is the first kernel.
// Redirect catchAll registrations to CompositeImplicitAutograd.
auto& k = dispatch_key.has_value() ? kernels_[*dispatch_key] : kernels_[DispatchKey::CompositeImplicitAutograd];
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
if (k[0].kernel.isValid()) {
#else
if (k.size() > 0) {
#endif
TORCH_WARN("Overriding a previously registered kernel for the same operator and the same dispatch key\n",
" operator: ", (schema_.has_value() ? toString(schema_->schema) : toString(name_)), "\n",
" ", (this->schema_.has_value() ? this->schema_->debug : "no debug info"), "\n",
" dispatch key: ", toString(dispatch_key), "\n",
" previous kernel: ", (cpp_signature_.has_value() ? cpp_signature_->debug : "no debug info"), "\n",
" new kernel: ", debug
);
}
#ifdef C10_DISPATCHER_ONE_KERNEL_PER_DISPATCH_KEY
k[0].kernel = std::move(kernel);
k[0].inferred_function_schema = std::move(inferred_function_schema);
k[0].debug = std::move(debug);
#else
k.emplace_front(std::move(kernel), std::move(inferred_function_schema), std::move(debug));
#endif
AnnotatedKernelContainerIterator inserted = k.begin();
// update the dispatch table, i.e. re-establish the invariant
// that the dispatch table points to the newest kernel
if (dispatch_key.has_value()) {
updateDispatchTable_(dispatcher, *dispatch_key);
} else {
updateDispatchTableFull_(dispatcher);
}
return inserted;
}
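A practical consequence of the override branch above: registering a second kernel for the same operator and the same dispatch key is allowed, the newest registration becomes the one the dispatch table points to, and the TORCH_WARN shown above fires. A sketch, reusing this article's myops::myadd example (myadd_cpu_v2 is hypothetical):
#include <ATen/ATen.h>
#include <torch/library.h>

at::Tensor myadd_cpu(const at::Tensor& self, const at::Tensor& other);     // from the earlier sketch
at::Tensor myadd_cpu_v2(const at::Tensor& self, const at::Tensor& other);  // hypothetical replacement

TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu);      // first registration for (myadd, CPU)
}

TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("myadd", myadd_cpu_v2);   // warns "Overriding a previously registered kernel ..."
                                   // and the dispatch table now points at myadd_cpu_v2
}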
The dispatch call path
build/aten/src/ATen/Functions.h contains the entry functions for the operators.
Taking torch::relu as an example, this is the call path from the entry function down to the final backend kernel function:
TORCH_API inline at::Tensor relu(const at::Tensor & self) {
  return at::_ops::relu::call(self);
}
In build/aten/src/ATen/Operators_4.cpp, the operator name is used to look up a TypedOperatorHandle object, whose call method is then invoked:
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, name, "aten::relu")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, overload_name, "")
STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(relu, schema_str, "relu(Tensor self) -> Tensor")

// aten::relu(Tensor self) -> Tensor
static C10_NOINLINE c10::TypedOperatorHandle<relu::schema> create_relu_typed_handle() {
  return c10::Dispatcher::singleton()
      .findSchemaOrThrow(relu::name, relu::overload_name)
      .typed<relu::schema>();
}

// aten::relu(Tensor self) -> Tensor
at::Tensor relu::call(const at::Tensor & self) {
  static auto op = create_relu_typed_handle();
  return op.call(self);
}

// aten::relu(Tensor self) -> Tensor
at::Tensor relu::redispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) {
  static auto op = create_relu_typed_handle();
  return op.redispatch(dispatchKeySet, self);
}
TypedOperatorHandle::call forwards to Dispatcher::call:
C10_ALWAYS_INLINE Return call(Args... args) const {
  return c10::Dispatcher::singleton().call<Return, Args...>(*this, std::forward<Args>(args)...);
}
Dispatcher::call derives the dispatch key set from the arguments, uses the highest-priority key to look up the KernelFunction, and then calls KernelFunction::call:
template<class Return, class... Args>
C10_DISPATCHER_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandle<Return(Args...)>& op, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  auto dispatchKeySet = op.operatorDef_->op.dispatchKeyExtractor()
    .template getDispatchKeySetUnboxed<Args...>(args...);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!c10::isAliasDispatchKey(dispatchKeySet.highestPriorityTypeId()));
  const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet.highestPriorityTypeId());
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING
  // By default, when there're no high-frequency or non-sampled callbacks,
  // RecordFunction is pre-sampled as a perf optimization;
  // shouldRunRecordFunction checks whether RecordFunction should be executed,
  // and sets pre_sampled boolean argument value to whether pre-sampling was used -
  // this boolean is passed into RecordFunction to adjust the sampling rates of
  // the callbacks
  bool pre_sampled = false;
  if (C10_UNLIKELY(at::shouldRunRecordFunction(&pre_sampled))) {
    return callWithDispatchKeySlowPath<Return, Args...>(op, pre_sampled, dispatchKeySet, kernel, std::forward<Args>(args)...);
  }
#endif  // PYTORCH_DISABLE_PER_OP_PROFILING
  return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
}
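A small, hedged illustration of the key extraction step (not Dispatcher internals): the key set comes from the tensor arguments, and the highest-priority key in it is what lookup() receives. For an ordinary dense CPU float tensor that key is typically an Autograd key, which is why the autograd kernel described in the next section runs first.
#include <torch/torch.h>
#include <iostream>

int main() {
  at::Tensor x = torch::randn({3}, torch::requires_grad());
  c10::DispatchKeySet ks = x.key_set();

  std::cout << std::boolalpha
            << ks.has(c10::DispatchKey::CPU) << "\n"           // backend key
            << ks.has(c10::DispatchKey::AutogradCPU) << "\n";  // autograd key (higher priority)
  // The key that wins the first dispatch:
  std::cout << c10::toString(ks.highestPriorityTypeId()) << "\n";
  return 0;
}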
Finally, KernelFunction::call (the same code shown earlier in the KernelFunction section) invokes the backend kernel function:
template<class Return, class... Args>
C10_ALWAYS_INLINE Return KernelFunction::call(const OperatorHandle& opHandle, DispatchKeySet dispatchKeySet, Args... args) const {
  // note: Args above is intentionally not Args&&. We don't want perfect
  // forwarding, which would require Args to be deduced, but instead we
  // want callers to explicitly specify the Args.
  if (C10_LIKELY(unboxed_kernel_func_ != nullptr)) {
    return callUnboxedKernelFunction<Return, Args...>(unboxed_kernel_func_, functor_.get(), dispatchKeySet, std::forward<Args>(args)...);
  }

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      boxed_kernel_func_ != nullptr,
      "Tried to call KernelFunction::call() on an uninitialized KernelFunction."
  );

  return impl::BoxedKernelWrapper<Return(Args...)>::call(
      boxed_kernel_func_,
      functor_.get(),
      opHandle,
      dispatchKeySet,
      std::forward<Args>(args)...
  );
}
Autograd
Among the dispatch keys carried by a typical eager-mode call, the Autograd key currently has the highest priority, and most operators participate in autograd, so the first backend kernel function an operator call enters is usually its autograd function. Taking relu as the example again, its generated autograd function looks like this:
at::Tensor relu(c10::DispatchKeySet ks, const at::Tensor & self) {
auto& self_ = unpack(self, "self", 0);
auto _any_requires_grad = compute_requires_grad( self );
(void)_any_requires_grad;
std::shared_ptr<ReluBackward0> grad_fn;
if (_any_requires_grad) {
grad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);
grad_fn->set_next_edges(collect_next_edges( self ));
}
#ifndef NDEBUG
c10::optional<Storage> self__storage_saved =
self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
c10::intrusive_ptr<TensorImpl> self__impl_saved;
if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
#endif
auto _tmp = ([&]() {
at::AutoDispatchBelowADInplaceOrView guard;
return at::redispatch::relu(ks & c10::after_autograd_keyset, self_);
})();
auto result = std::move(_tmp);
#ifndef NDEBUG
if (self__storage_saved.has_value())
AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
if (result.has_storage()) AT_ASSERT(result.storage().use_count() == 1, "function: relu");
AT_ASSERT(result.use_count() <= 1, "function: relu");
#endif
if (grad_fn) {
set_history(flatten_tensor_args( result ), grad_fn);
}
throw_error_for_complex_autograd(result, "relu");
TORCH_CHECK_NOT_IMPLEMENTED(!(isFwGradDefined(self)), "Trying to use forward AD with relu that does not support it.");
if (grad_fn) {
grad_fn->result_ = SavedVariable(result, true);
}
return result;
}
After the autograd function has done its bookkeeping, it ends by calling at::redispatch::relu, which re-dispatches the call. Since autograd has already been handled, the dispatch key set used for the redispatch is updated (the autograd keys are masked out), and following the call chain again eventually executes the backend kernel function that performs relu's forward computation:
template<class Return, class... Args>
inline Return Dispatcher::redispatch(const TypedOperatorHandle<Return (Args...)>& op, DispatchKeySet currentDispatchKeySet, Args... args) const {
  detail::unused_arg_(args...);  // workaround for a false-positive warning about unused parameters in gcc 5
  // do not use RecordFunction on redispatch
  const KernelFunction& kernel = op.operatorDef_->op.lookup(currentDispatchKeySet.highestPriorityTypeId());
  return kernel.template call<Return, Args...>(op, currentDispatchKeySet, std::forward<Args>(args)...);
}
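Finally, to connect registration, autograd and redispatch: a hedged sketch (adapted from the dispatcher-extension tutorial) of a hand-written autograd kernel for the myops::myadd example. Like the generated relu autograd kernel above, it does the autograd bookkeeping and then drops below the autograd keys before calling back into the dispatcher, so the second dispatch reaches the backend (e.g. CPU) kernel.
#include <torch/torch.h>
#include <torch/library.h>
#include <ATen/core/dispatch/Dispatcher.h>

using torch::autograd::AutogradContext;
using torch::autograd::tensor_list;

class MyAddFunction : public torch::autograd::Function<MyAddFunction> {
 public:
  static at::Tensor forward(AutogradContext* ctx, at::Tensor self, at::Tensor other) {
    // Exclude the autograd (and ADInplaceOrView) keys while we call back into
    // the dispatcher, so the next lookup picks the backend kernel.
    at::AutoDispatchBelowADInplaceOrView guard;
    static auto op = c10::Dispatcher::singleton()
        .findSchemaOrThrow("myops::myadd", "")
        .typed<at::Tensor(const at::Tensor&, const at::Tensor&)>();
    return op.call(self, other);
  }

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs) {
    // d(self + other)/d self = 1 and d(self + other)/d other = 1
    auto grad = grad_outputs[0];
    return {grad, grad};
  }
};

at::Tensor myadd_autograd(const at::Tensor& self, const at::Tensor& other) {
  return MyAddFunction::apply(self, other);
}

// Register the autograd kernel under the Autograd key, just like the CPU
// kernel was registered under the CPU key earlier.
TORCH_LIBRARY_IMPL(myops, Autograd, m) {
  m.impl("myadd", myadd_autograd);
}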