MLIR:新建一个Dialect(五),通过.td定义新Op

文章来自微信公众号“科文路”,欢迎关注、互动。转发须注明出处。

Multi-Level Intermediate Representation(MLIR)是创建可重用、可扩展编译器基础设施的新途径。本文为第 11 期,继续介绍一个简单的 MLIR Dialect.

转载请注明出处!

MLIR 项目的核心是 Dialect,MLIR 自身就拥有例如 linalg、tosa、affine 这些 Dialect。各种不同的 Dialect 使不同类型的优化或转换得以完成。

接上回,本文继续新建一个 Dialect 的内容。本文开始逐一解析项目的各个实现部分,首先是通过 .td 文件、基于 Hello_Op 基类定义新的 Op。

复习

工具链、总览等等知识请自行翻看历史 MLIR 标签的相关文章

前文介绍了mlir-hello 项目的目标就是使用自建的 Dialect 通过 MLIR 生态实现一个 hello world,具体做法为:

  1. 创建 hello-opt 将原始 print.mlir (可以理解成 hello world 的 main.cpp)转换为 print.ll 文件
  2. 使用 LLVM 的 lli 解释器直接运行 print.ll 文件

HelloOps.td

hello.print 作为一个 Op,显而易见,hello Dialect、print Op 都需要被定义。

本文来看看如何定义一个保存常量值的 ConstantOp 和执行打印操作的 PrintOp,也就是实际 MLIR 使用中的 hello.constant 和 hello.print。

通过声明式的 .td 文件以及 TableGen 工具,可以便捷地生成相应的 C++ 代码。

更详细的语法可以在 Operation Definition Specification (ODS)找到。

代码来自 [mlir-hello]/include/Hello/HelloOps.td

#ifndef HELLO_OPS
#define HELLO_OPS

include "HelloDialect.td"
// Provides the side-effect traits used below, e.g. `Pure`. `Pure` declares an
// op free of memory effects and always speculatable; that *enables* the
// optimizer to CSE, hoist, or erase the op when its results are unused.
include "mlir/Interfaces/SideEffectInterfaces.td"

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

// `hello.constant`: turns a literal into an SSA value.
// `Hello_Op` (presumably declared in HelloDialect.td) is the dialect's op base
// class; its mnemonic argument "constant" becomes `hello.constant` in IR, and
// ODS emits a C++ class `ConstantOp` deriving from `mlir::Op<ConstantOp, ...>`
// (CRTP). `Pure` is appropriate here: an unused constant may safely be removed.
def ConstantOp : Hello_Op<"constant", [Pure]> {
  // One-line summary of the op.
  let summary = "constant";
  // Longer description, used in generated documentation.
  let description = [{
    Constant operation turns a literal into an SSA value. The data is attached
    to the operation as an attribute. For example:

    ```mlir
    %0 = "hello.constant"()
      { value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
      : () -> tensor<2x3xf64>
    ```
  }];

  // Extra C++ builders (https://mlir.llvm.org/docs/OpDefinitions/#builder-methods).
  // Each entry adds a `static void build(OpBuilder &, OperationState &, ...)`
  // overload; the `ins` dag lists that overload's extra parameters.
  let builders = [
    // Build from a DenseElementsAttr, using the attribute's type as the
    // result type.
    OpBuilder<(ins "mlir::DenseElementsAttr":$value), [{
      build($_builder, $_state, value.getType(), value);
    }]>,
    // Build from a scalar double. No body is given, so ODS only declares this
    // overload; its definition must be written by hand in the C++ sources.
    OpBuilder<(ins "double":$value)>
  ];

  // Inputs may be operands (SSA values produced by other ops) or attributes
  // (compile-time constants). The only input here is an attribute, so the op
  // has zero operands.
  let arguments = (ins F64ElementsAttr:$value);
  // A single f64 tensor result.
  let results = (outs F64Tensor);
}

//===----------------------------------------------------------------------===//
// PrintOp
//===----------------------------------------------------------------------===//

// `hello.print`: prints its input and produces no results.
// Deliberately NOT `Pure`: an op with no results and no memory effects is
// trivially dead in MLIR, so DCE/canonicalization would be free to erase the
// print. With an empty trait list its effects stay unknown, so it is kept.
def PrintOp : Hello_Op<"print", []> {
  let summary = "print operation";
  let description = [{
    The "print" builtin operation prints a given input tensor, and produces
    no results.
  }];

  // The print operation takes an input tensor (or memref) to print.
  let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
  // Declarative assembly format: ODS generates parse()/print() from it, so the
  // op is written in IR as e.g. `hello.print %0 : tensor<2x3xf64>`.
  let assemblyFormat = "$input attr-dict `:` type($input)";
}

#endif // HELLO_OPS

TableGen

来看看这个 .td 能生成什么样子的代码?

$MLIR_TBLGEN -gen-op-decls include/Hello/HelloOps.td -I$LOCAL_MLIR/include -Iinclude/Hello > HelloOps.decls.h

/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
|* *|
|* Op Declarations *|
|* *|
|* Automatically generated file, do not edit! *|
|* *|
\*===----------------------------------------------------------------------===*/

// Output of `mlir-tblgen -gen-op-decls` for HelloOps.td. Change the .td file
// and regenerate instead of editing this file by hand.
// NOTE(review): llvm::Optional and llvm::makeArrayRef below indicate an older
// LLVM; newer releases replaced them with std::optional / ArrayRef, so expect
// this output to differ after an LLVM upgrade.

// Forward declarations of the op classes. Emitted when the including file
// defines GET_OP_FWD_DEFINES or GET_OP_CLASSES before the #include.
#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)
#undef GET_OP_FWD_DEFINES
namespace hello {
class ConstantOp;
} // namespace hello
namespace hello {
class PrintOp;
} // namespace hello
#endif

// Full class declarations, emitted only when GET_OP_CLASSES is defined.
#ifdef GET_OP_CLASSES
#undef GET_OP_CLASSES


//===----------------------------------------------------------------------===//
// Local Utility Method Definitions
//===----------------------------------------------------------------------===//

namespace hello {

//===----------------------------------------------------------------------===//
// ::hello::ConstantOp declarations
//===----------------------------------------------------------------------===//

// Adaptor for ConstantOp: named access to the op's operands and its `value`
// attribute from a raw ValueRange + DictionaryAttr, i.e. without needing a
// ConstantOp instance.
class ConstantOpAdaptor {
public:
ConstantOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});

ConstantOpAdaptor(ConstantOp op);

::mlir::ValueRange getOperands();
std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
::mlir::ValueRange getODSOperands(unsigned index);
::mlir::DictionaryAttr getAttributes();
::mlir::DenseElementsAttr getValueAttr();
::mlir::DenseElementsAttr getValue();
::mlir::LogicalResult verify(::mlir::Location loc);
private:
::mlir::ValueRange odsOperands;
::mlir::DictionaryAttr odsAttrs;
::mlir::RegionRange odsRegions;
::llvm::Optional<::mlir::OperationName> odsOpName;
};
// C++ class for `hello.constant`, derived from mlir::Op<ConstantOp, ...>
// (CRTP). The template arguments are the op's traits:
//  - ZeroOperands: `value` is an attribute (F64ElementsAttr), not an operand.
//  - OneResult / OneTypedResult<TensorType>: the single F64Tensor result.
//  - ConditionallySpeculatable::Trait, AlwaysSpeculatableImplTrait and
//    MemoryEffectOpInterface::Trait: contributed by the `Pure` trait in the .td.
class ConstantOp : public ::mlir::Op<ConstantOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::TensorType>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> {
public:
using Op::Op;
using Op::print;
using Adaptor = ConstantOpAdaptor;
public:
// The op's inherent attribute names; `value` is the only one.
static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")};
return ::llvm::makeArrayRef(attrNames);
}

::mlir::StringAttr getValueAttrName() {
return getAttributeNameForIndex(0);
}

static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) {
return getAttributeNameForIndex(name, 0);
}

// Full op name: dialect prefix "hello" + mnemonic "constant".
static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("hello.constant");
}

std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
::mlir::Operation::operand_range getODSOperands(unsigned index);
std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
::mlir::Operation::result_range getODSResults(unsigned index);
// Accessors for the `value` attribute.
::mlir::DenseElementsAttr getValueAttr();
::mlir::DenseElementsAttr getValue();
void setValueAttr(::mlir::DenseElementsAttr attr);
// The first two build() overloads come from the custom `builders` list in the
// .td; the `double` one has no body there, so it must be defined by hand in
// C++. The remaining three are the default builders generated by ODS.
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, mlir::DenseElementsAttr value);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double value);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value);
static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
::mlir::LogicalResult verifyInvariantsImpl();
::mlir::LogicalResult verifyInvariants();
// MemoryEffectOpInterface hook (from `Pure`); its definition is in the
// -gen-op-defs output, not in this file.
void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
private:
// Looks up attribute names stored on the registered operation info; the
// asserts allow only index 0 (`value`) and only the "hello.constant" name.
::mlir::StringAttr getAttributeNameForIndex(unsigned index) {
return getAttributeNameForIndex((*this)->getName(), index);
}

static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) {
assert(index < 1 && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
return name.getRegisteredInfo()->getAttributeNames()[index];
}

public:
};
} // namespace hello
// Declares an explicit TypeID for the op class (presumably paired with a
// matching definition in the -gen-op-defs output).
MLIR_DECLARE_EXPLICIT_TYPE_ID(::hello::ConstantOp)

namespace hello {

//===----------------------------------------------------------------------===//
// ::hello::PrintOp declarations
//===----------------------------------------------------------------------===//

// Adaptor for PrintOp: named access to the `input` operand from a raw
// ValueRange + DictionaryAttr, without needing a PrintOp instance.
class PrintOpAdaptor {
public:
PrintOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});

PrintOpAdaptor(PrintOp op);

::mlir::ValueRange getOperands();
std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
::mlir::ValueRange getODSOperands(unsigned index);
::mlir::Value getInput();
::mlir::DictionaryAttr getAttributes();
::mlir::LogicalResult verify(::mlir::Location loc);
private:
::mlir::ValueRange odsOperands;
::mlir::DictionaryAttr odsAttrs;
::mlir::RegionRange odsRegions;
::llvm::Optional<::mlir::OperationName> odsOpName;
};
// C++ class for `hello.print`: one operand (`input`), zero results.
// NOTE(review): the speculatability/MemoryEffectOpInterface traits here are
// what ODS emits for an op declared `Pure`. MLIR treats an op with no results
// and no memory effects as trivially dead, so DCE/canonicalization may erase
// this print -- confirm that is intended.
class PrintOp : public ::mlir::Op<PrintOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::ZeroResults, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> {
public:
using Op::Op;
// Keep the base-class print overloads visible alongside the generated
// print(OpAsmPrinter &) declared below.
using Op::print;
using Adaptor = PrintOpAdaptor;
public:
// PrintOp has no inherent attributes.
static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
return {};
}

static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("hello.print");
}

std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
::mlir::Operation::operand_range getODSOperands(unsigned index);
::mlir::Value getInput();
::mlir::MutableOperandRange getInputMutable();
std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
::mlir::Operation::result_range getODSResults(unsigned index);
// Default builders generated by ODS (the .td declares no custom builders).
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input);
static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
::mlir::LogicalResult verifyInvariantsImpl();
::mlir::LogicalResult verifyInvariants();
// parse()/print() are generated from the declarative `assemblyFormat`
// ("$input attr-dict `:` type($input)") in the .td.
static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
void print(::mlir::OpAsmPrinter &_odsPrinter);
void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
public:
};
} // namespace hello
// Declares an explicit TypeID for the op class (presumably paired with a
// matching definition in the -gen-op-defs output).
MLIR_DECLARE_EXPLICIT_TYPE_ID(::hello::PrintOp)


#endif // GET_OP_CLASSES

TLDR

这些只是 Op 的声明,还有对应的定义,内容太多就不放在这里了。可以通过下面的指令生成:

$MLIR_TBLGEN -gen-op-defs include/Hello/HelloOps.td -I$LOCAL_MLIR/include -Iinclude/Hello > HelloOps.defs.h

手写这些代码还是写个 .td 自动生成,你看着办吧~😎

本期结语

本文对 mlir-hello 项目的源代码文件 HelloOps.td 进行了学习:借助自定义 .td 文件的声明式语法,可以在新的 Dialect 中便捷地定义新的 Op。我们下期继续。

都看到这儿了,不如关注每日推送的“科文路”、互动起来~

至少点个赞再走吧~

MLIR:新建一个Dialect(五),通过.td定义新Op

https://xlindo.com/kewenlu2022/posts/e75331b8/

Author

xlindo

Posted on

2022-11-17

Updated on

2023-05-10

Licensed under

Comments