// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
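// The constant lowering below converts the tensor type of "hello.constant"
// into a matching memref type. A minimal sketch of that helper, assuming the
// usual toy-tutorial shape:
static mlir::MemRefType convertTensorToMemRef(mlir::TensorType type) {
  // Only ranked tensors can be lowered to statically shaped memrefs here.
  assert(type.hasRank() && "expected only ranked shapes");
  return mlir::MemRefType::get(type.getShape(), type.getElementType());
}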
// Insert an allocation and deallocation for the given MemRefType.
static mlir::Value insertAllocAndDealloc(mlir::MemRefType type, mlir::Location loc,
                                         mlir::PatternRewriter &rewriter) {
  auto alloc = rewriter.create<mlir::memref::AllocOp>(loc, type);

  // Make sure to allocate at the beginning of the block.
  auto *parentBlock = alloc->getBlock();
  alloc->moveBefore(&parentBlock->front());
  // Make sure to deallocate this alloc at the end of the block. This is fine
  // as toy functions have no control flow.
  auto dealloc = rewriter.create<mlir::memref::DeallocOp>(loc, alloc);
  dealloc->moveBefore(&parentBlock->back());

  return alloc;
}
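// With the placement above, every lowered function body ends up roughly shaped
// as follows (illustrative sketch only):
//
//   ^bb0:
//     %0 = memref.alloc() : memref<...>
//     ...                      // element stores and uses of %0
//     memref.dealloc %0 : memref<...>
//     <terminator>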
class ConstantOpLowering : public mlir::OpRewritePattern<hello::ConstantOp> {
  using OpRewritePattern<hello::ConstantOp>::OpRewritePattern;
  mlir::LogicalResult matchAndRewrite(hello::ConstantOp op,
                                      mlir::PatternRewriter &rewriter) const final {
    // The constant's value attribute and location (the attribute accessor is
    // assumed to be the generated getValue()).
    mlir::DenseElementsAttr constantValue = op.getValue();
    mlir::Location loc = op.getLoc();

    // When lowering the constant operation, we allocate and assign the constant
    // values to a corresponding memref allocation.
    auto tensorType = op.getType().cast<mlir::TensorType>();
    auto memRefType = convertTensorToMemRef(tensorType);
    auto alloc = insertAllocAndDealloc(memRefType, loc, rewriter);
    // We will be generating constant indices up to the largest dimension.
    // Create these constants up-front to avoid large amounts of redundant
    // operations.
    auto valueShape = memRefType.getShape();
    mlir::SmallVector<mlir::Value, 8> constantIndices;
    if (!valueShape.empty()) {
      for (auto i : llvm::seq<int64_t>(
               0, *std::max_element(valueShape.begin(), valueShape.end())))
        constantIndices.push_back(
            rewriter.create<mlir::arith::ConstantIndexOp>(loc, i));
    } else {
      // This is the case of a tensor of rank 0.
      constantIndices.push_back(
          rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0));
    }

    // The constant operation represents a multi-dimensional constant, so we
    // will need to generate a store for each of the elements. The following
    // functor recursively walks the dimensions of the constant shape,
    // generating a store when the recursion hits the base case.
    // [4, 3] (1, 2, 3, 4, 5, 6, 7, 8)
    // storeElements(0)
    //   indices = [0]
    //   storeElements(1)
    //     indices = [0, 0]
    //     storeElements(2)
    //       store (const 1) [0, 0]
    //     indices = [0]
    //     indices = [0, 1]
    //     storeElements(2)
    //       store (const 2) [0, 1]
    //   ...
    //
    mlir::SmallVector<mlir::Value, 2> indices;
    auto valueIt = constantValue.getValues<mlir::FloatAttr>().begin();
    std::function<void(uint64_t)> storeElements = [&](uint64_t dimension) {
      // The last dimension is the base case of the recursion; at this point
      // we store the element at the given index.
      if (dimension == valueShape.size()) {
        rewriter.create<mlir::AffineStoreOp>(
            loc, rewriter.create<mlir::arith::ConstantOp>(loc, *valueIt++),
            alloc, llvm::makeArrayRef(indices));
        return;
      }
      // Otherwise, iterate over the current dimension and add the indices to
      // the list.
      for (uint64_t i = 0, e = valueShape[dimension]; i != e; ++i) {
        indices.push_back(constantIndices[i]);
        storeElements(dimension + 1);
        indices.pop_back();
      }
    };
    // Start the element storing recursion from the first dimension.
    storeElements(/*dimension=*/0);
    // Replace this operation with the generated alloc.
    rewriter.replaceOp(op, alloc);
    return mlir::success();
  }
};
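// For a concrete feel of what this pattern emits, a 2x3 "hello.constant"
// lowers to IR roughly of the following shape (illustrative sketch; exact
// value names and syntax depend on the MLIR version in use):
//
//   %alloc = memref.alloc() : memref<2x3xf64>
//   %c0 = arith.constant 0 : index
//   %c1 = arith.constant 1 : index
//   %c2 = arith.constant 2 : index
//   %cst = arith.constant 1.000000e+00 : f64
//   affine.store %cst, %alloc[%c0, %c0] : memref<2x3xf64>
//   ...  (one affine.store per element)
//   memref.dealloc %alloc : memref<2x3xf64>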
class PrintOpLowering : public mlir::OpConversionPattern<hello::PrintOp> {
  using OpConversionPattern<hello::PrintOp>::OpConversionPattern;
  mlir::LogicalResult matchAndRewrite(hello::PrintOp op, OpAdaptor adaptor,
                                      mlir::ConversionPatternRewriter &rewriter) const final {
    // We don't lower "hello.print" in this pass, but we need to update its
    // operands.
    rewriter.updateRootInPlace(op,
                               [&] { op->setOperands(adaptor.getOperands()); });
    return mlir::success();
  }
};
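// Note on the pattern above: "hello.print" itself survives this pass; only its
// operands are rewired to the converted values, so after lowering it reads
// from the allocated buffer (e.g. a memref<2x3xf64> where the original IR had
// a tensor<2x3xf64>).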
namespace {
class HelloToAffineLowerPass
    : public mlir::PassWrapper<HelloToAffineLowerPass,
                               mlir::OperationPass<mlir::ModuleOp>> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(HelloToAffineLowerPass)