概述
MLIR Toy Tutorial 的目标是通过构建一门编程语言编译器的完整过程(包括前端和后端技术),教授如何使用 MLIR 的各个组件来实现语言的解析、转换和代码生成等功能。
Chapter3 介绍了如何在 Canonicalizer pass 上应用自定义的 rewrite pattern 来重写优化 IR。回顾一下相关开发流程:
- 自定义继承于
mlir::OpRewritePattern
的 RewritePattern,在MatchAndRewrite()
中定义匹配和重写的规则。 - 在 op 定义(
Ops.td
)中声明let hasCanonicalizer = 1;
,并实现getCanonicalizationPatterns()
方法,用它来调用那些自定义的 RewritePattern。 - 通过
mlir::createCanonicalizerPass()
创建 Canonicalizer pass,pass 会遍历所有 op,并调用它们的钩子(hook)函数 -- getCanonicalizationPatterns() 来重写 IR。
总之,Canonicalizer pass 通过提供 getCanonicalizationPatterns() 这个 hook(也就是回调函数)可以让各 op 灵活滴实现自定义操作。将这个设计进一步发扬光大就成了 Chapter4 要介绍的技术:接口(interfaces)。
这里说的接口和面向对象里的接口概念差不多,是一种将一组操作组织在一起的抽象机制,使得不同的实体可以共享相同的操作集合,从而实现代码的重用和灵活性。MLIR 为 pass、dialect 和 op 都提供有各种操作接口,用户只为这些组件实现相应的接口方法就可以应用这些操作。
Chapter4 以形状推断(shape inference)为例,介绍如何使用接口。
Shape Inference
形状推断是编译时的重要环节,它可以在运行时根据输入张量构建出一张静态计算图,用于后续如常量折叠、剪枝等图优化操作。它通常需要遍历所有的 op,根据 op 的输入张量的类型(包括形状和数据类型)来推断输出张量的类型,因此,每个 op 都要对外提供一个能做形状推断的方法,在Ch4,这个方法是 inferShapes()
。
Ch4 创建了一个 pass -- ShapeInferencePass
来遍历所有函数的 ops,找出会输出动态形状的 op,然后调用它们的 inferShapes() 来做形状推断。出于简单和教学的需要,Ch4 会先把所有函数中都内联(inlining)到 main 函数,这样 pass 只需要处理一个函数即可。
Inlining
内联函数是一个常用的操作,MLIR 自然也已经提供有相关功能模块,允许不同方言(dialect)之间的 ops 进行内联。Toy dialect 只要应用相关接口 -- DialectInlinerInterface
,就能复用这些功能。
要应用 DialectInlinerInterface 接口,除了需要在 dialect 里声明使用该接口,还需要实现它的方法:isLegalToInline()、handleTerminator() 和 materializeCallConversion() 等:
/// Dialect initialization, the instance will be owned by the context. This is
/// the point of registration of types and operations for the dialect.
void ToyDialect::initialize() {
...
addInterfaces<ToyInlinerInterface>();
}
/// This class defines the interface for handling inlining with Toy
/// operations.
struct ToyInlinerInterface : public DialectInlinerInterface {
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
return true;
}
...
Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type resultType,
Location conversionLoc) const final {
return builder.create<CastOp>(conversionLoc, resultType, input);
}
};
- isLegalToInline() 用于判断两组 ops 做内联是否合法。
- handleTerminator() 用来处理
toy.return
op。 - materializeCallConversion() 用来函数形参和实参的类型转换。
除此之外,由于内联函数需要确定被内联的对象(Callable
)和 ops插入的位置(Caller
),MLIR 也提供了相关接口 -- CallOpInterface
/CallableOpInterface
,来定义 ops 的角色。另外,被内联的函数存在形参和实参类型不匹配的情况,需要引入 CastOp
来对它们做转换。
内联部分涉及的内容较多,由于篇幅所限,这里不做过多展开,详情请看原文。
ShapeInferenceOpInterface
现在我们已经内联了所有的函数,剩下的就只有主函数,它包含静态和动态形状的操作:
toy.func @main() {
%0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%1 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
%2 = toy.cast %1 : tensor<2x3xf64> to tensor<*xf64>
%3 = toy.cast %0 : tensor<2x3xf64> to tensor<*xf64>
%4 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
%5 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
%6 = toy.mul %4, %5 : tensor<*xf64>
toy.print %6 : tensor<*xf64>
toy.return
}
前面说过,ShapeInferencePass 会遍历 main 函数的所有含有动态形状(tensor<*xf64>
)的 op,然后调用它们的 inferShapes
方法来推断出静态形状(如 tensor<2x3xf64>
)。
Ch4 定义了 ShapeInferenceOpInterface 接口,这个接口会提供 inferShapes
方法,这样 Ops 只要应用该接口并实现 inferShapes() 即可:
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
...
let methods = [
InterfaceMethod<"Infer and set the output shape for the current operation.",
"void", "inferShapes">
];
}
def MulOp : Toy_Op<"mul",
[..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
...
}
void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
这里,MulOp::inferShapes()
会将它的输出张量的数据类型和形状设置成和输入张量一样。
总结
Chapter4 以内联函数和形状推断为例,介绍了 MLIR 的接口技术的使用方法。MLIR 中的接口是一种抽象机制,用于定义和共享操作集合,促进代码的重用和模块化,从而提高代码的灵活性和可维护性。 MLIR pass、dialect 和 op 都可以定义和使用接口。
END
`