MLIR:新建一个Dialect(七),lowering

文章来自微信公众号“科文路”,欢迎关注、互动。转发须注明出处。

Multi-Level Intermediate Representation(MLIR)是创建可重用、可扩展编译器基础设施的新途径。本文为第 13 期,继续介绍一个简单的 MLIR Dialect.

转载请注明出处!

MLIR 项目的核心是 Dialect,MLIR 自身就拥有例如linalgtosaaffine 这些 Dialect。各种不同的 Dialect 使不同类型的优化或转换得以完成。

好了,如果说前面的部分算是 MLIR 的坡道起步,那这一节就要开始弹射起飞了。本期开始讲解 Dialect 的 Lowering,即由 MLIR 代码逐级转换为机器代码的过程。

当然了,前期也提到过,MLIR 生态的目标只在中间阶段,所以其 lowering 本质上并不涉及太多最终的 IR 生成,这一部分更依赖 LLVM 的基石。

内容较多,建议收藏、细品。

复习

工具链、总览等等知识请自行翻看历史 MLIR 标签的相关文章

mlir-hello 项目的目标就是使用自建的 Dialect 通过 MLIR 生态实现一个 hello world,具体做法为:

  1. 创建 hello-opt 将原始 print.mlir (可以理解成 hello world 的 main.cpp)转换为 print.ll 文件
  2. 使用 LLVM 的 lli 解释器直接运行 print.ll 文件

前文主要介绍了如何通过 ODS 实现新的 Dialect/Op 的定义。

Lowering

MLIR 看似清爽,但相关 Pass 的实现一样工作量巨大。

在定义和编写了 HelloDialect 的方方面面后,最终还是要使它们回归 LLVM MLIR “标准库” Dialect,从而再做面向硬件的代码生成。因为标准库中的 Dialect 的剩余工作可以“无痛”衔接到 LLVM 的基础组件上。

具体到 mlir-hello,HelloDialect 到 LLVM 标准库 Dialect,例如 affine dialectllvm dialect 的 lowering 将手工编码完成。

这一部分可能是 MLIR 相关任务工作量最大的地方。

这一篇文章作为 lowering 相关内容的开端,来解读如何通过 C++ 实现 HelloDialectaffine dialect 的 lowering。

相关文件如下:

  • mlir-hello/include/Hello/HelloDialect.h,主要内容通过前期讲过的 ODS 自动生成,略
  • mlir-hello/include/Hello/HelloOps.h,主要内容通过前期讲过的 ODS 自动生成,略
  • mlir-hello/include/Hello/HelloPasses.h,注册本不存在的 lowering pass,比如 Hello 到 Affine 的 pass
  • mlir-hello/lib/Hello/LowerToAffine.cpp,lowering pass 的实现

代码解读

简单讲,Dialect 到 Dialect 是一个 match and rewrite 的过程。

注意,有一个之前介绍过的、在 MLIR 中被大量应用的 C++ 编程技巧可能需要巩固一下:C++:CRTP,传入继承

Pass registration

mlir-hello/include/Hello/HelloPasses.h

通过 std::unique_ptr<mlir::Pass> 在 MLIR 中注册两条 lowering pass。

注册的这个函数钩子将会在下一节的 cpp 中得到具体的实现的函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// mlir-hello/include/Hello/HelloPasses.h
// 该文件在 MLIR 中注册两条 lowering pass,没啥特别的

#ifndef MLIR_HELLO_PASSES_H
#define MLIR_HELLO_PASSES_H

#include <memory>

#include "mlir/Pass/Pass.h"

namespace hello {
std::unique_ptr<mlir::Pass> createLowerToAffinePass();
std::unique_ptr<mlir::Pass> createLowerToLLVMPass();
}

#endif // MLIR_HELLO_PASSES_H

Lowering implementation

mlir-hello/lib/Hello/LowerToAffine.cpp

负责 helloaffine 的 lowering 实现,本质上分为各 Op lowering 的前置工作Dialect to Dialect实现两个部分。最终的实现 createLowerToAffinePass 将作为 Pass 注册时函数钩子的返回。

1. Op lowering

例如对于某 Xxx 算子,共性为

  • 定义为 class XxxOpLowering
  • 继承自 mlir::OpRewritePattern<hello::XxxOp>
  • 重载 matchAndRewrite 函数,做具体实现
  • XxxOpLowering 最终将作为模板参数传入新 pass 的 mlir::RewritePatternSet<XxxOpLowering>

例如 class ConstantOpLowering 的实现如下:它会将 ConstantOp 所携带的信息最终转储到 mlir::AffineStoreOp 中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class ConstantOpLowering : public mlir::OpRewritePattern<hello::ConstantOp> {
using OpRewritePattern<hello::ConstantOp>::OpRewritePattern;

mlir::LogicalResult matchAndRewrite(hello::ConstantOp op, mlir::PatternRewriter &rewriter) const final {
// 捕获 ConstantOp 的元信息:值、位置
mlir::DenseElementsAttr constantValue = op.getValue();
mlir::Location loc = op.getLoc();

// lowering 时,需要将 constant 的参数转存为 memref
auto tensorType = op.getType().cast<mlir::TensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

// 预先声明一个“最高维”的变量
auto valueShape = memRefType.getShape();
mlir::SmallVector<mlir::Value, 8> constantIndices;

if (!valueShape.empty()) {
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(loc, i));
} else {
// rank 为 0 时
constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0));
}

// ConstantOp 将作为一个“多维常量”被使用,它可能包含下面这些隐含信息(结构、值),
// [4, 3] (1, 2, 3, 4, 5, 6, 7, 8)
// storeElements(0)
// indices = [0]
// storeElements(1)
// indices = [0, 0]
// storeElements(2)
// store (const 1) [0, 0]
// indices = [0]
// indices = [0, 1]
// storeElements(2)
// store (const 2) [0, 1]
// ...

// 于是,可以通过定义一个递归 functor (中文一般译为仿函数)去捕获这些信息。
// functor 的基本思路为,从第一个维度开始,向第 2, 3,...个维度递归取回每个维度上的元素。
mlir::SmallVector<mlir::Value, 2> indices;
auto valueIt = constantValue.getValues<mlir::FloatAttr>().begin();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension "&") {
// 递归边界情况:到了最后一维,直接存下整组值
if (dimension == valueShape.size()) {
rewriter.create<mlir::AffineStoreOp>(
loc, rewriter.create<mlir::arith::ConstantOp>(loc, *valueIt++), alloc,
llvm::makeArrayRef(indices));
return;
}
// 未到递归边界:在当前维度上挨个儿递归,存储结构信息
for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
indices.push_back(constantIndices[i]);
storeElements(dimension + 1);
indices.pop_back();
}
};

// 使用上面的 functor
storeElements(/*dimension=*/0);

// 将 insertAllocAndDealloc 替换为当前 op
rewriter.replaceOp(op, alloc);
return mlir::success();
}
};

2. Dialect to Dialect

定义好 op 的 lowering 后,就可以通过点对点的 lowering pass 说明如何进行 Dialect 之间的转换了。

这里的 class HelloToAffineLowerPass 主要需要实现 runOnOperation 函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
namespace {
// 继承 PassWrapper,定义 HelloToAffineLowerPass,它将作为函数钩子的实现返回到上面的 pass 注册
class HelloToAffineLowerPass : public mlir::PassWrapper<HelloToAffineLowerPass, mlir::OperationPass<mlir::ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLowerPass)

// 依赖哪些标准库里的 Dialect
void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::AffineDialect, mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
}

void runOnOperation() final;
};
}

// 需要实现的函数,它来说明如何做 lowering
void HelloToAffineLowerPass::runOnOperation() {
// 获取上下文
mlir::ConversionTarget target(getContext());

// 在 addIllegalDialect 中将 HelloDialect 置为不合法(需要被lowering)
target.addIllegalDialect<hello::HelloDialect>();
// 说明哪些 Dialect 是合法(lowering目标,通常是标准库中的 Dialect)的
target.addLegalDialect<mlir::AffineDialect, mlir::BuiltinDialect,
mlir::func::FuncDialect, mlir::arith::ArithDialect, mlir::memref::MemRefDialect>();
// 后期可通过 `isDynamicallyLegal` 决定其是否合法,这里具体表现为“当 PrintOp 的参数合法时,它才合法”
target.addDynamicallyLegalOp<hello::PrintOp>([](hello::PrintOp op "") {
return llvm::none_of(op->getOperandTypes(),
[](mlir::Type type "") { return type.isa<mlir::TensorType>(); });
});

// 说明如何 lowering,只需要把 illegal 的 op 的 lowering 实现作为模板参数传入 RewritePatternSet
mlir::RewritePatternSet patterns(&getContext());
patterns.add<ConstantOpLowering, PrintOpLowering>(&getContext());

if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
}
}

本期结语

Pass 的实现确实工作量比较大,但是又不可避免,因为新的 Dialect 到标准库 Dialect 的过程还是必定需要手工实现。这也是很多反对 MLIR 的声音的来源。我们下期继续。

附全部代码

mlir-hello/lib/Hello/LowerToAffine.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "Hello/HelloDialect.h"
#include "Hello/HelloOps.h"
#include "Hello/HelloPasses.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"

static mlir::MemRefType convertTensorToMemRef(mlir::TensorType type) {
assert(type.hasRank() && "expected only ranked shapes");
return mlir::MemRefType::get(type.getShape(), type.getElementType());
}

static mlir::Value insertAllocAndDealloc(mlir::MemRefType type, mlir::Location loc,
mlir::PatternRewriter &rewriter) {
auto alloc = rewriter.create<mlir::memref::AllocOp>(loc, type);

// Make sure to allocate at the beginning of the block.
auto *parentBlock = alloc->getBlock();
alloc->moveBefore(&parentBlock->front());

// Make sure to deallocate this alloc at the end of the block. This is fine
// as toy functions have no control flow.
auto dealloc = rewriter.create<mlir::memref::DeallocOp>(loc, alloc);
dealloc->moveBefore(&parentBlock->back());
return alloc;
}

class ConstantOpLowering : public mlir::OpRewritePattern<hello::ConstantOp> {
using OpRewritePattern<hello::ConstantOp>::OpRewritePattern;

mlir::LogicalResult matchAndRewrite(hello::ConstantOp op, mlir::PatternRewriter &rewriter) const final {
mlir::DenseElementsAttr constantValue = op.getValue();
mlir::Location loc = op.getLoc();

// When lowering the constant operation, we allocate and assign the constant
// values to a corresponding memref allocation.
auto tensorType = op.getType().cast<mlir::TensorType>();
auto memRefType = convertTensorToMemRef(tensorType);
auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);

// We will be generating constant indices up-to the largest dimension.
// Create these constants up-front to avoid large amounts of redundant
// operations.
auto valueShape = memRefType.getShape();
mlir::SmallVector<mlir::Value, 8> constantIndices;

if (!valueShape.empty()) {
for (auto i : llvm::seq<int64_t>(
0, *std::max_element(valueShape.begin(), valueShape.end())))
constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(loc, i));
} else {
// This is the case of a tensor of rank 0.
constantIndices.push_back(rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0));
}
// The constant operation represents a multi-dimensional constant, so we
// will need to generate a store for each of the elements. The following
// functor recursively walks the dimensions of the constant shape,
// generating a store when the recursion hits the base case.

// [4, 3] (1, 2, 3, 4, 5, 6, 7, 8)
// storeElements(0)
// indices = [0]
// storeElements(1)
// indices = [0, 0]
// storeElements(2)
// store (const 1) [0, 0]
// indices = [0]
// indices = [0, 1]
// storeElements(2)
// store (const 2) [0, 1]
// ...
//
mlir::SmallVector<mlir::Value, 2> indices;
auto valueIt = constantValue.getValues<mlir::FloatAttr>().begin();
std::function<void(uint64_t)> storeElements = [&](uint64_t dimension "&") {
// The last dimension is the base case of the recursion, at this point
// we store the element at the given index.
if (dimension == valueShape.size()) {
rewriter.create<mlir::AffineStoreOp>(
loc, rewriter.create<mlir::arith::ConstantOp>(loc, *valueIt++), alloc,
llvm::makeArrayRef(indices));
return;
}

// Otherwise, iterate over the current dimension and add the indices to
// the list.
for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
indices.push_back(constantIndices[i]);
storeElements(dimension + 1);
indices.pop_back();
}
};

// Start the element storing recursion from the first dimension.
storeElements(/*dimension=*/0);

// Replace this operation with the generated alloc.
rewriter.replaceOp(op, alloc);
return mlir::success();
}
};

class PrintOpLowering : public mlir::OpConversionPattern<hello::PrintOp> {
using OpConversionPattern<hello::PrintOp>::OpConversionPattern;

mlir::LogicalResult matchAndRewrite(hello::PrintOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const final {
// We don't lower "hello.print" in this pass, but we need to update its
// operands.
rewriter.updateRootInPlace(op,
[&] { op->setOperands(adaptor.getOperands()); });
return mlir::success();
}
};

namespace {
class HelloToAffineLowerPass : public mlir::PassWrapper<HelloToAffineLowerPass, mlir::OperationPass<mlir::ModuleOp>> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLowerPass)

void getDependentDialects(mlir::DialectRegistry &registry) const override {
registry.insert<mlir::AffineDialect, mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
}

void runOnOperation() final;
};
}

void HelloToAffineLowerPass::runOnOperation() {
mlir::ConversionTarget target(getContext());

target.addIllegalDialect<hello::HelloDialect>();
target.addLegalDialect<mlir::AffineDialect, mlir::BuiltinDialect,
mlir::func::FuncDialect, mlir::arith::ArithDialect, mlir::memref::MemRefDialect>();
target.addDynamicallyLegalOp<hello::PrintOp>([](hello::PrintOp op "") {
return llvm::none_of(op->getOperandTypes(),
[](mlir::Type type "") { return type.isa<mlir::TensorType>(); });
});

mlir::RewritePatternSet patterns(&getContext());
patterns.add<ConstantOpLowering, PrintOpLowering>(&getContext());

if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) {
signalPassFailure();
}
}

std::unique_ptr<mlir::Pass> hello::createLowerToAffinePass() {
return std::make_unique<HelloToAffineLowerPass>();
}

都看到这儿了,不如关注每日推送的“科文路”、互动起来~

至少点个赞再走吧~

MLIR:新建一个Dialect(七),lowering

https://xlindo.com/kewenlu2022/posts/24659433/

Author

xlindo

Posted on

2022-12-08

Updated on

2023-05-10

Licensed under

Comments