/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file provides basic utilities for the elemental lowering of
// each LHLO op.
#include "mlir-hlo/Dialect/lhlo/transforms/lhlo_elemental_utils.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/lhlo/IR/lhlo_ops.h"
#include "mlir-hlo/Dialect/lhlo/transforms/map_lmhlo_to_scalar_op.h"
#include "mlir-hlo/utils/codegen_utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
using mlir::memref::DimOp;
using mlir::memref::LoadOp;
using mlir::memref::StoreOp;
namespace mlir {
namespace lmhlo {
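
// Loads the element of `memref` at `indices`, unless a store to the same
// memref at the same indices is already visible before `insertPoint` in the
// same Block; in that case the previously stored value is reused instead of
// emitting a fresh load.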
Value createLoadOrUseCachedValue(Location loc, OpBuilder* b, Value memref,
ValueRange indices,
OpBuilder::InsertPoint insertPoint) {
  // Check if there is any cached value within the current Block that can be
  // reused. Alternatively, we could do this for all Blocks that dominate
  // this Block, but that would be more complicated.
std::vector<StoreOp> storeOps;
insertPoint.getBlock()->walk(
insertPoint.getBlock()->begin(), insertPoint.getPoint(),
[&](StoreOp storeOp) {
if (storeOp.getOperation()->getBlock() != insertPoint.getBlock())
return;
if ((storeOp.getMemRef() == memref) &&
(storeOp.getIndices() == indices))
storeOps.emplace_back(storeOp);
});
if (!storeOps.empty()) return storeOps[0].getOperand(0);
int rank = memref.getType().dyn_cast<MemRefType>().getRank();
return rank > 0 ? b->create<LoadOp>(loc, memref, indices)
: b->create<LoadOp>(loc, memref);
}
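
// Returns the subset of `ops` that are dead: an op is live if its result
// buffer is read by a memref.load, or if it (transitively) produces an input
// of a live LHLO op; all remaining ops compute results nobody consumes.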
DenseSet<Operation*> NoLoaderUser(SmallVectorImpl<Operation*>& ops) {
SmallVector<Operation*, 4> worklist;
DenseSet<Operation*> hasLoaderOps;
for (Operation* op : ops) {
Value memref = cast<LmhloOp>(op).getResultBuffer();
if (memref == nullptr) continue;
for (auto* user : memref.getUsers()) {
if (isa<memref::LoadOp>(user)) {
worklist.push_back(op);
hasLoaderOps.insert(op);
}
}
}
while (!worklist.empty()) {
Operation* op = worklist.pop_back_val();
int numOperands = op->getNumOperands();
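    // Walk only the input operands; the last operand is assumed to be the
    // op's result buffer.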
for (int i = 0; i < numOperands - 1; ++i) {
Value memref = op->getOperand(i);
for (Operation* user : memref.getUsers()) {
if ((!isa<LmhloOp>(user)) || hasLoaderOps.count(user)) continue;
if (cast<LmhloOp>(user).getResultBuffer() == memref) {
worklist.push_back(user);
hasLoaderOps.insert(user);
}
}
}
}
DenseSet<Operation*> noLoaderOps;
for (Operation* op : ops)
if (!hasLoaderOps.count(op)) noLoaderOps.insert(op);
return noLoaderOps;
}
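
// Erases every non-terminator LHLO op in `parent` whose result is never
// consumed (see NoLoaderUser above).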
void cleanUnusedLhloOps(Block* parent) {
SmallVector<Operation*, 4> lhloOps;
for (Operation& op : parent->getOperations()) {
if (op.getDialect() == op.getContext()->getLoadedDialect("lmhlo") &&
(!isa<lmhlo::TerminatorOp>(op)))
lhloOps.push_back(&op);
}
  const DenseSet<Operation*> noLoaderUser = NoLoaderUser(lhloOps);
for (auto* lhloOp : noLoaderUser) lhloOp->erase();
}
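
// Emits the scalar computation of one element of `op`: `outputIndex` is the
// index into the op's result buffer, and the returned Value is the operand
// element that the op maps to that result position. When `checkCache` is
// true, a store of the same element earlier in the block is reused instead
// of a fresh load.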
template <typename LHLO_OpTy>
Value elementalLower(OpBuilder* b, Location loc, LHLO_OpTy op,
ValueRange outputIndex, bool checkCache);
template <>
Value elementalLower<lmhlo::RealDynamicSliceOp>(OpBuilder* b, Location loc,
lmhlo::RealDynamicSliceOp op,
ValueRange outputIndex,
bool checkCache) {
Value startIndicesMemref = op->getOperand(1);
Value stridesMemref = op->getOperand(3);
int rank = outputIndex.size();
SmallVector<Value, 4> inputIndex;
for (int dim = 0; dim < rank; ++dim) {
SmallVector<Value, 4> dimIndex;
dimIndex.push_back(b->create<arith::ConstantOp>(
loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), dim)));
auto startIndexLoad =
b->create<LoadOp>(loc, startIndicesMemref, ValueRange{dimIndex});
auto startIndex =
b->create<arith::IndexCastOp>(loc, b->getIndexType(), startIndexLoad);
auto strideLoad =
b->create<LoadOp>(loc, stridesMemref, ValueRange{dimIndex});
auto stride =
b->create<arith::IndexCastOp>(loc, b->getIndexType(), strideLoad);
// input_dim = out_dim * stride + start_index
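    // e.g. with stride 2 and start_index 1, output index 3 reads input
    // index 3 * 2 + 1 = 7.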
auto inputDim = b->create<arith::AddIOp>(
loc, b->create<arith::MulIOp>(loc, outputIndex[dim], stride),
startIndex);
inputIndex.push_back(inputDim);
}
  Value operandMemref = op->getOperand(0);
if (!checkCache) return b->create<LoadOp>(loc, operandMemref, inputIndex);
return createLoadOrUseCachedValue(loc, b, operandMemref, inputIndex,
b->saveInsertionPoint());
}
namespace {
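// Computes the operand index that a (Dynamic)BroadcastInDimOp reads for a
// given result index. broadcast_dimensions[i] is the result dimension that
// operand dimension i maps to; result dimensions without a matching entry
// are pure expansion dimensions and need no operand index. For a mapped
// dimension, the operand index is 0 when the operand extent is 1 (the
// dimension is broadcast) and the result index otherwise.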
template <typename T>
Value elementalLowerImplForBroadcastInDimOps(OpBuilder* b, Location loc,
T broadcastInDim,
ValueRange outputIndex,
bool checkCache) {
auto broadcastDimensions =
broadcastInDim.broadcast_dimensions().template getValues<int64_t>();
int outRank = outputIndex.size();
Value operandMemref = broadcastInDim->getOperand(0);
SmallVector<Value, 4> inputIndex;
for (int64_t dim = 0; dim < outRank; ++dim) {
auto it =
std::find(broadcastDimensions.begin(), broadcastDimensions.end(), dim);
bool isBroadcastDim = (it != broadcastDimensions.end());
if (isBroadcastDim) {
int inputDim = std::distance(broadcastDimensions.begin(), it);
int64_t staticDimSize =
operandMemref.getType().cast<MemRefType>().getShape()[inputDim];
if (staticDimSize == 1) {
        // The operand dim size is statically 1, so this dim is broadcast:
        // always read operand index 0.
auto zero = b->create<arith::ConstantOp>(
loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
inputIndex.push_back(zero);
} else if (staticDimSize == ShapedType::kDynamicSize) {
        // The operand dim size is dynamic, so whether this dim is broadcast
        // is only known at runtime: read index 0 iff the runtime size is 1.
auto dimSize = b->create<DimOp>(loc, operandMemref, inputDim);
auto one = b->create<arith::ConstantOp>(
loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 1));
auto zero = b->create<arith::ConstantOp>(
loc, b->getIndexType(), b->getIntegerAttr(b->getIndexType(), 0));
auto dimSizeIs1 = b->create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, dimSize, one);
inputIndex.push_back(b->create<mlir::arith::SelectOp>(
loc, dimSizeIs1, zero, outputIndex[dim]));
} else {
        // The operand dim size is statically greater than 1, so this dim is
        // not broadcast; read the output index directly.
inputIndex.push_back(outputIndex[dim]);
}
}
}
if (!checkCache) {
int rank = operandMemref.getType().dyn_cast<MemRefType>().getRank();
return (rank > 0) ? b->create<LoadOp>(loc, operandMemref, inputIndex)
: b->create<LoadOp>(loc, operandMemref, ValueRange());
}
return createLoadOrUseCachedValue(loc, b, operandMemref, inputIndex,
b->saveInsertionPoint());
}
} // namespace
template <>
Value elementalLower<lmhlo::DynamicBroadcastInDimOp>(
OpBuilder* b, Location loc, lmhlo::DynamicBroadcastInDimOp op,
ValueRange outputIndex, bool checkCache) {
return elementalLowerImplForBroadcastInDimOps(b, loc, op, outputIndex,
checkCache);
}
template <>
Value elementalLower<lmhlo::BroadcastInDimOp>(OpBuilder* b, Location loc,
lmhlo::BroadcastInDimOp op,
ValueRange outputIndex,
bool checkCache) {
return elementalLowerImplForBroadcastInDimOps(b, loc, op, outputIndex,
checkCache);
}
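
// Creates an scf.for loop over [lb, ub) with the given step, moves the
// builder's insertion point to the start of the loop body, and returns the
// induction variable through `var`.
//
// Typical use, as a sketch (`b`, `loc`, and the bound Values are assumed to
// already exist):
//   Value iv;
//   auto forOp = createLoopAndSetInsPt(b, loc, iv, lb, ub, step, {});
//   // ... emit the loop body with `b`, indexed by `iv` ...
//   b.setInsertionPointAfter(forOp.getOperation());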
scf::ForOp createLoopAndSetInsPt(OpBuilder& b, Location loc, Value& var,
Value lb, Value ub, Value step,
ArrayRef<Value> initValues) {
auto forOp = b.create<scf::ForOp>(loc, lb, ub, step, initValues);
b.setInsertionPointToStart(forOp.getBody());
var = forOp.getInductionVar();
return forOp;
}
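
// Creates an scf.parallel loop over the given bounds, moves the builder's
// insertion point to the start of its body, and appends the induction
// variables to `vars`.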
scf::ParallelOp createParallelAndSetInsPt(OpBuilder& b, Location loc,
SmallVectorImpl<Value>& vars,
ArrayRef<Value> lbs,
ArrayRef<Value> ubs,
ArrayRef<Value> steps,
ArrayRef<Value> initValues) {
auto parOp = b.create<scf::ParallelOp>(loc, lbs, ubs, steps, initValues,
/*bodyBuilderFn=*/nullptr);
b.setInsertionPointToStart(parOp.getBody());
vars.append(parOp.getInductionVars().begin(), parOp.getInductionVars().end());
return parOp;
}

// Reinterpret-casts the input memref (which must have an identity layout)
// into a 1-D memref covering the same number of elements.
memref::ReinterpretCastOp createMemRef1DReinterpretCast(OpBuilder& b,
Location loc,
Value memref) {
auto memrefTy = memref.getType().cast<MemRefType>();
assert(memrefTy.getLayout().isIdentity());
Value size = codegen_utils::emitNumElementsComputation(b, loc, memref);
Value stride = b.create<mlir::arith::ConstantOp>(
loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 1));
Value zero = b.create<mlir::arith::ConstantOp>(
loc, b.getIndexType(), b.getIntegerAttr(b.getIndexType(), 0));
auto memref1dType =
MemRefType::get({ShapedType::kDynamicSize}, memrefTy.getElementType(),
b.getMultiDimIdentityMap(1), memrefTy.getMemorySpace());
return b.create<memref::ReinterpretCastOp>(
loc, memref1dType, memref, zero, ValueRange{size}, ValueRange{stride});
}
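
// Stores `res` into `memref` at the linear `offset`, viewing the memref as a
// flat 1-D buffer.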
void createOffsetStore(OpBuilder& b, Location loc, Value res, Value memref,
Value offset) {
Value memref1d = createMemRef1DReinterpretCast(b, loc, memref);
b.create<memref::StoreOp>(loc, res, memref1d, ValueRange{offset});
}
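
// Loads the element of `memref` at the linear `offset`, viewing the memref
// as a flat 1-D buffer.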
memref::LoadOp createOffsetLoad(OpBuilder& b, Location loc, Value memref,
Value offset) {
Value memref1d = createMemRef1DReinterpretCast(b, loc, memref);
return b.create<memref::LoadOp>(loc, memref1d, ValueRange{offset});
}
} // namespace lmhlo
} // namespace mlir