blob: 82aa67fc1754109cd8e7d42e33c36469eb67b225 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
#include <algorithm>
#include <numeric>
#include <string>
#include <utility>
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir-hlo/Dialect/mhlo/transforms/type_conversion.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace {
/// Returns the attributes attached to `op` whose names are not among the
/// names reported by `op.getAttributeNames()`, keeping their original order.
template <typename OpTy>
SmallVector<NamedAttribute> PruneAttributeList(OpTy op) {
  // Names reported by the op itself; these are the ones to drop.
  auto declared_names = op.getAttributeNames();
  llvm::StringSet<> declared;
  declared.insert(declared_names.begin(), declared_names.end());

  SmallVector<NamedAttribute> kept;
  for (NamedAttribute attr : op->getAttrs()) {
    if (!declared.count(attr.getName())) kept.push_back(attr);
  }
  return kept;
}
/// Returns a vector of `nLoops` iterator type names. The first
/// `nLoops - nReduction` entries are "parallel" and the trailing `nReduction`
/// entries are "reduction".
///
/// `nReduction` must not exceed `nLoops`.
SmallVector<StringRef, 3> GetParallelAndReductionIterators(
    unsigned nLoops, unsigned nReduction) {
  // Both counts are unsigned, so a larger `nReduction` would make the
  // subtraction below wrap around to a huge parallel-loop count.
  assert(nReduction <= nLoops &&
         "number of reduction loops exceeds total number of loops");
  SmallVector<StringRef, 3> res(nLoops - nReduction,
                                getParallelIteratorTypeName());
  res.append(nReduction, getReductionIteratorTypeName());
  return res;
}
/// Returns `nParallelLoops` iterator type names with no reduction entries.
SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  constexpr unsigned kNoReductionLoops = 0;
  return GetParallelAndReductionIterators(/*nLoops=*/nParallelLoops,
                                          /*nReduction=*/kNoReductionLoops);
}
/// Returns the first result of `op`.
Value GetResultValue(Operation* op) {
  constexpr unsigned kFirstResult = 0;
  return op->getResult(kFirstResult);
}
/// Returns the type of the first result of `op` as a ShapedType. Uses `cast`,
/// so the result type is required to be a ShapedType.
ShapedType GetHloOpResultType(Operation* op) {
  Type first_result_type = GetResultValue(op).getType();
  return first_result_type.cast<ShapedType>();
}
/// Returns true iff every operand and every result of `op` has a
/// RankedTensorType.
bool VerifyHloOpBufferOrTensorSemantics(Operation* op) {
  auto is_ranked_tensor = [](Value v) {
    return v.getType().isa<RankedTensorType>();
  };
  return llvm::all_of(op->getOperands(), is_ranked_tensor) &&
         llvm::all_of(op->getResults(), is_ranked_tensor);
}
/// Creates a linalg.init_tensor with the shape and element type of `type`;
/// `dyn_sizes` are forwarded as the op's size operands.
Value GetInitTensor(OpBuilder& b, Location loc, ShapedType type,
                    ArrayRef<Value> dyn_sizes) {
  auto static_shape = type.getShape();
  Type element_type = type.getElementType();
  return b.create<linalg::InitTensorOp>(loc, dyn_sizes, static_shape,
                                        element_type);
}
/// Creates a sparse_tensor.init op of the given `type` from `sizes`.
Value GetInitSparseTensor(OpBuilder& b, Location loc, ShapedType type,
                          ArrayRef<Value> sizes) {
  Value sparse_init = b.create<sparse_tensor::InitOp>(loc, type, sizes);
  return sparse_init;
}
/// Creates the initial (output) tensor for the lowering of `op` with result
/// type `result_type`.
///
/// A dense result only needs its dynamic extents; a sparse result needs every
/// extent. Required extents are read from the shape that `op` reifies through
/// InferShapedTypeOpInterface, or materialized as index constants when they
/// are static.
Value GetInitTensorFor(OpBuilder& b, Location loc, ShapedType result_type,
                       Operation* op, ValueRange operands) {
  const bool sparse_result =
      sparse_tensor::getSparseTensorEncoding(result_type) != nullptr;

  SmallVector<Value> sizes;
  if (result_type.hasRank() &&
      (sparse_result || !result_type.hasStaticShape())) {
    // Ask the op for its output shape. The returned status is deliberately
    // ignored; only the number of reified shapes is asserted.
    auto shape_source = cast<InferShapedTypeOpInterface>(op);
    SmallVector<Value, 1> reified_shapes;
    (void)shape_source.reifyReturnTypeShapes(b, operands, reified_shapes);
    assert(reified_shapes.size() == 1 && "Expected one reified result");

    auto shape = result_type.getShape();
    for (int64_t dim = 0, rank = shape.size(); dim < rank; ++dim) {
      const int64_t extent = shape[dim];
      if (extent == ShapedType::kDynamicSize) {
        // Dynamic extent: extract it from the reified shape tensor.
        Value dim_index = b.create<arith::ConstantIndexOp>(loc, dim);
        sizes.push_back(b.create<tensor::ExtractOp>(loc, reified_shapes[0],
                                                    ValueRange{dim_index}));
      } else if (sparse_result) {
        // Static extent: only the sparse case needs it spelled out.
        sizes.push_back(b.create<arith::ConstantIndexOp>(loc, extent));
      }
    }
  }

  if (sparse_result) return GetInitSparseTensor(b, loc, result_type, sizes);
  return GetInitTensor(b, loc, result_type, sizes);
}
/// Copies the integer values of `elements` into a vector of int64_t, using
/// `APInt::getLimitedValue()` for each element.
SmallVector<int64_t, 4> Extract1DVector(DenseIntElementsAttr elements) {
  return llvm::to_vector<4>(llvm::map_range(elements, [](APInt element) {
    return static_cast<int64_t>(element.getLimitedValue());
  }));
}
/// If `init` is produced by a constant whose value is a rank-0
/// DenseElementsAttr, returns its single element; otherwise returns a null
/// Attribute.
Attribute GetInitValueAsConst(Value init) {
  DenseElementsAttr constant;
  const bool is_constant = matchPattern(init, m_Constant(&constant));
  if (!is_constant) return Attribute();

  auto constant_type = constant.getType().dyn_cast<ShapedType>();
  const bool is_rank_zero = constant_type && constant_type.getRank() == 0;
  if (!is_rank_zero) return Attribute();

  return constant.getValues<Attribute>()[0];
}
/// Builds the permutation that lists every non-reduction dimension of a
/// `rank`-dimensional space in increasing order, followed by `reduction_dims`
/// in the order given, and returns the inverse of that permutation map.
/// E.g., for `rank` 4 and `reduction_dims` {1, 3} the permutation is
/// "(d0, d1, d2, d3) -> (d0, d2, d1, d3)" and its inverse is returned.
AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
                                      ArrayRef<int64_t> reduction_dims) {
  llvm::SmallSetVector<int, 4> reduced;
  for (int64_t dim : reduction_dims) reduced.insert(dim);

  SmallVector<unsigned, 4> order;
  // Parallel dimensions first ...
  for (int dim = 0; dim < rank; ++dim) {
    if (reduced.count(dim) == 0) order.push_back(dim);
  }
  // ... then the reduction dimensions.
  for (int64_t dim : reduction_dims) order.push_back(dim);

  return inversePermutation(AffineMap::getPermutationMap(order, context));
}
/// Returns true iff `attr` is a splat whose splat value equals `value`.
bool isSplatValue(DenseIntElementsAttr attr, uint64_t value) {
  if (!attr.isSplat()) return false;
  return attr.getSplatValue<uint64_t>() == value;
}
/// Returns true if `dimension_numbers` from a mhlo.convolution op is in
/// canonical form:
///
/// * Input dimensions have order: (batch_count, spatial_dims,
///   input_channel_count).
/// * Filter dimensions have order: (spatial_dims, input_channel_count,
///   output_channel_count).
/// * Output dimensions have order: (batch_count, spatial_dims,
///   output_channel_count).
static bool HasCanonicalDimensionNumbers(
    mhlo::ConvDimensionNumbersAttr dimension_numbers) {
  auto input_spatial = dimension_numbers.getInputSpatialDimensions();
  auto kernel_spatial = dimension_numbers.getKernelSpatialDimensions();
  auto output_spatial = dimension_numbers.getOutputSpatialDimensions();

  // Input: batch first, feature right after the spatial dimensions.
  const int input_spatial_rank = llvm::size(input_spatial);
  const bool input_ok =
      dimension_numbers.getInputBatchDimension() == 0 &&
      dimension_numbers.getInputFeatureDimension() == (input_spatial_rank + 1);
  if (!input_ok) return false;

  // Filter: spatial dimensions first, then input feature, then output
  // feature.
  const int kernel_spatial_rank = llvm::size(kernel_spatial);
  const bool kernel_ok =
      dimension_numbers.getKernelInputFeatureDimension() ==
          kernel_spatial_rank &&
      dimension_numbers.getKernelOutputFeatureDimension() ==
          (kernel_spatial_rank + 1);
  if (!kernel_ok) return false;

  // Output: batch first, feature right after the spatial dimensions.
  const int output_spatial_rank = llvm::size(output_spatial);
  const bool output_ok = dimension_numbers.getOutputBatchDimension() == 0 &&
                         dimension_numbers.getOutputFeatureDimension() ==
                             (output_spatial_rank + 1);
  if (!output_ok) return false;

  // All three must agree on the number of spatial dimensions.
  const bool ranks_match = input_spatial_rank == output_spatial_rank &&
                           input_spatial_rank == kernel_spatial_rank;
  if (!ranks_match) return false;

  // Spatial dimensions must be listed in increasing order: 1..N for input and
  // output, 0..N-1 for the kernel.
  for (int i = 0; i < input_spatial_rank; ++i) {
    const int expected = i + 1;
    const bool ordered = input_spatial[i] == expected &&
                         output_spatial[i] == expected &&
                         kernel_spatial[i] == i;
    if (!ordered) return false;
  }
  return true;
}
//===----------------------------------------------------------------------===//
// mhlo.RngUniformOp conversion patterns.
//===----------------------------------------------------------------------===//
// Pattern lowering mhlo.rng_uniform to a linalg.generic that computes a
// stateless uniform pseudo-random value per element with an LCG algorithm
// driven by the element's indices.
//
// Only floating-point min/max operands are handled; anything else is reported
// as a match failure.
struct RngUniformConversion : public OpConversionPattern<mhlo::RngUniformOp> {
  using OpConversionPattern<mhlo::RngUniformOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::RngUniformOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    // TODO(raikonenfnu): Handle other element types as well.
    // Operand 0 is the lower bound (min), operand 1 the upper bound (max).
    // Both must be shaped types with a floating-point element type; the
    // dyn_cast results are checked before being dereferenced.
    auto min_ty = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
    auto max_ty = adaptor.getOperands()[1].getType().dyn_cast<ShapedType>();
    if (!min_ty || !max_ty || !min_ty.getElementType().isa<FloatType>() ||
        !max_ty.getElementType().isa<FloatType>()) {
      return rewriter.notifyMatchFailure(
          op, "expected min/max for rng op to be FloatType");
    }
    // The type converter returns a null type on failure, so use the
    // null-tolerant cast and actually test the result.
    auto target_ty = this->typeConverter->convertType(op.getResult().getType())
                         .dyn_cast_or_null<ShapedType>();
    if (!target_ty) {
      return rewriter.notifyMatchFailure(
          op, "expected target shape of rng op to be ShapedType");
    }
    auto loc = op.getLoc();
    Value init_tensor =
        GetInitTensorFor(rewriter, loc, target_ty, op, adaptor.getOperands());
    // Creates index map using target matrix's rank. The two inputs (min and
    // max) use a map with no results; the output uses the identity map.
    auto target_rank = target_ty.getRank();
    SmallVector<AffineMap, 3> indexing_maps(
        2, AffineMap::get(target_rank, /*symbolCount=*/0,
                          SmallVector<AffineExpr>({}), rewriter.getContext()));
    indexing_maps.push_back(rewriter.getMultiDimIdentityMap(target_rank));
    const int kInitialSeed = 0;
    // Generic region with LCG Algorithm that make use of element index from:
    // https://reviews.llvm.org/D101364
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/target_ty,
        /*inputs=*/
        ValueRange{adaptor.getOperands()[0], adaptor.getOperands()[1]},
        /*outputs=*/init_tensor, indexing_maps,
        GetParallelAndReductionIterators(/*nLoops=*/target_rank,
                                         /*nReduction=*/0),
        [&](OpBuilder& b, Location nested_loc, ValueRange args) {
          llvm::SmallVector<Value> update_vec = {b.create<arith::ConstantOp>(
              nested_loc, b.getI32IntegerAttr(kInitialSeed))};
          Value multiplier = b.create<arith::ConstantOp>(
              nested_loc, b.getI32IntegerAttr(1103515245));
          Value incrementStep = b.create<arith::ConstantOp>(
              nested_loc, b.getI32IntegerAttr(12345));
          // For output matrix with rank N:
          // temp1 = (cast(I32, index(D.0)) + seed) * mult + incr
          // ...
          // tempN = (cast(I32, index(D.(N))) + tempN_1) * mult + incr
          for (int i = 0; i < target_rank; i++) {
            Value update = update_vec.back();
            Value ind = b.create<linalg::IndexOp>(nested_loc, i);
            Value cast_ind =
                b.create<arith::IndexCastOp>(nested_loc, b.getI32Type(), ind);
            Value add_res =
                b.create<arith::AddIOp>(nested_loc, cast_ind, update);
            Value mult_res =
                b.create<arith::MulIOp>(nested_loc, add_res, multiplier);
            Value inc_res =
                b.create<arith::AddIOp>(nested_loc, mult_res, incrementStep);
            update_vec.push_back(inc_res);
          }
          // Scaling = (max - min) * const(F64, 2.3283064E-10)
          // which is derived from rand(min,max) = rand()/(RAND_MAX/(max-min)).
          Value epsilon = b.create<arith::ConstantOp>(
              nested_loc, b.getFloatAttr(args[0].getType(), 2.3283064E-10));
          Value range = b.create<arith::SubFOp>(nested_loc, args[1], args[0]);
          Value scale = b.create<arith::MulFOp>(nested_loc, range, epsilon);
          // Res = cast(T, cast(F64, tempN) * scaling + min)
          Value update_cast = b.create<arith::UIToFPOp>(
              nested_loc, target_ty.getElementType(), update_vec.back());
          Value scale_update =
              b.create<arith::MulFOp>(nested_loc, update_cast, scale);
          Value res = b.create<arith::AddFOp>(nested_loc, scale_update, args[0]);
          b.create<linalg::YieldOp>(nested_loc, res);
        },
        PruneAttributeList(op));
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};
//===----------------------------------------------------------------------===//
// mhlo.Einsum conversion patterns.
//===----------------------------------------------------------------------===//
// Produces one iterator type name per index in `input_ind`, in order:
// "reduction" if the index is contained in `reduction_dims`, "parallel"
// otherwise.
SmallVector<StringRef, 3> GetEinsumLoopsAttrs(
    const llvm::SmallSetVector<StringRef, 4>& input_ind,
    const llvm::SmallSetVector<StringRef, 4>& reduction_dims) {
  SmallVector<StringRef, 3> iterator_types;
  iterator_types.reserve(input_ind.size());
  for (StringRef index : input_ind) {
    const bool is_reduction = reduction_dims.contains(index);
    iterator_types.push_back(is_reduction ? getReductionIteratorTypeName()
                                          : getParallelIteratorTypeName());
  }
  return iterator_types;
}
// Collects a tensor.dim value for every dynamic dimension of the einsum
// output, in output order.
//
// Each output index in `output_loop_vec` is looked up in `lhs_loop_vec` first
// and in `rhs_loop_vec` otherwise; the matching position selects the operand
// dimension to query. Output dimensions whose source dimension is static are
// skipped. Every output index must appear in at least one input, and both
// operands must be ranked tensors.
SmallVector<Value, 2> ExtractDynamicEinsumSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs,
    const SmallVector<std::string>& lhs_loop_vec,
    const SmallVector<std::string>& rhs_loop_vec,
    const SmallVector<std::string>& output_loop_vec) {
  SmallVector<Value, 2> dyn_sizes;
  for (const std::string& dim_ind : output_loop_vec) {
    // Prefer the lhs operand; fall back to the rhs operand.
    Value source = lhs;
    const SmallVector<std::string>* source_loop_vec = &lhs_loop_vec;
    const auto* dim_ind_it =
        std::find(lhs_loop_vec.begin(), lhs_loop_vec.end(), dim_ind);
    if (dim_ind_it == lhs_loop_vec.end()) {
      source = rhs;
      source_loop_vec = &rhs_loop_vec;
      dim_ind_it = std::find(rhs_loop_vec.begin(), rhs_loop_vec.end(), dim_ind);
      // Without this guarantee the position below would index past the end of
      // the rhs shape.
      assert(dim_ind_it != rhs_loop_vec.end() &&
             "einsum output index must appear in one of the inputs");
    }
    auto dim_ind_pos = dim_ind_it - source_loop_vec->begin();
    auto source_shape = source.getType().cast<RankedTensorType>().getShape();
    if (source_shape[dim_ind_pos] != ShapedType::kDynamicSize) continue;
    dyn_sizes.push_back(b.create<tensor::DimOp>(loc, source, dim_ind_pos));
  }
  return dyn_sizes;
}
// Returns the indices of `input_set` that do not occur in `output_set`, in
// `input_set` order. These are the axes that get summed over.
llvm::SmallSetVector<StringRef, 4> FindSummationAxes(
    const llvm::SmallSetVector<StringRef, 4>& input_set,
    const llvm::SmallSetVector<StringRef, 4>& output_set) {
  llvm::SmallSetVector<StringRef, 4> summed;
  for (StringRef axis : input_set) {
    if (output_set.contains(axis)) continue;
    summed.insert(axis);
  }
  return summed;
}
// Translates the per-dimension index names of one operand into affine
// expressions by looking each name up in `str_affine_dim_umap`.
// For example:
// let string_dim_umap = {'a' : d0, 'b' : d1, 'c' : d2}
// for einsum_config "abc,cb->acb"
// first_input_operand will get umap[{"a","b","c"}] -> (d0, d1, d2).
// second_input_operand will get umap[{"c","b"}] -> (d2, d1).
// output_operand will get umap[{"a","c","b"}] -> (d0, d2, d1).
SmallVector<AffineExpr> GetExprFromConfig(
    const SmallVector<std::string>& loop_dims,
    const DenseMap<StringRef, AffineExpr>& str_affine_dim_umap) {
  SmallVector<AffineExpr> dim_exprs;
  dim_exprs.reserve(loop_dims.size());
  for (const std::string& name : loop_dims) {
    AffineExpr expr = str_affine_dim_umap.lookup(name);
    dim_exprs.push_back(expr);
  }
  return dim_exprs;
}
// Convert mhlo.einsum op into linalg.generic.
// Algorithm in general 3 steps:
// Step1) Dissect entire einsum_config to different operands
// e.g f("abc,cd->abd") = {lhs:["abc"], rhs:["cd"], out:["abd"]}.
// Step2) Split up the string into vector of the elements
// e.g {lhs:["abc"], rhs:["cd"], out:["abd"]} = {lhs:["a","b","c"],
// rhs:["c","d"], out:["a","b","d"]}.
// Step3) Convert the vector into data access
// pattern represented by affineMaps with affineDimensions e.g
// {lhs:["a","b","c"], rhs:["c","d"], out:["a","b","d"]} = {lhs:[d0,d1,d2],
// rhs:[d2,d3], out:[d0,d1,d3]}.
//
// The config must have the form "lhs,rhs->out"; anything else is reported as
// a match failure. The element-wise product (and the accumulation, when there
// are summation axes) uses floating-point or integer arith ops depending on
// the result element type.
class EinsumToLinalgConverter : public OpConversionPattern<mhlo::EinsumOp> {
 public:
  using OpConversionPattern<mhlo::EinsumOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      mhlo::EinsumOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto get_rank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto einsum_config = op.einsum_config();
    // With the assumption of binary input operand and single output
    // get the inputs and output operands' indices.
    // einsum_config = "lhs_loop,rhs_loop->out_loop"
    std::size_t pos_arrow = einsum_config.find(kArrow);
    std::size_t pos_comma = einsum_config.find(kComma);
    // Both separators must be present, with the comma before the arrow.
    // Otherwise the offsets computed below would be derived from `npos` and
    // wrap around.
    if (pos_arrow == StringRef::npos || pos_comma == StringRef::npos ||
        pos_comma > pos_arrow) {
      return rewriter.notifyMatchFailure(op, "Invalid einsum config!");
    }
    StringRef lhs_loop = einsum_config.substr(0, pos_comma);
    StringRef rhs_loop = einsum_config.substr(
        pos_comma + kComma.size(), pos_arrow - (pos_comma + kComma.size()));
    StringRef out_loop = einsum_config.substr(pos_arrow + kArrow.size());
    // Check for Invalid Configs.
    // 1.Check that there is only maximum 2 inputs
    // 2.Check that there is only maximum 1 output
    // 3.Check that there is 1 kArrow
    if (rhs_loop.find(kComma) != StringRef::npos ||
        out_loop.find(kComma) != StringRef::npos ||
        out_loop.find(kArrow) != StringRef::npos) {
      return rewriter.notifyMatchFailure(op, "Invalid einsum config!");
    }
    // Find result type, if on tensors.
    auto result_ty = this->typeConverter->convertType(GetHloOpResultType(op))
                         .dyn_cast<RankedTensorType>();
    // Check result type compatibility.
    if (!result_ty || !(result_ty.getElementType().isSignlessIntOrFloat())) {
      return rewriter.notifyMatchFailure(op, "Invalid result type");
    }
    // The type check above admits integers as well as floats, so the body
    // must pick the matching arithmetic ops.
    const bool is_float = result_ty.getElementType().isa<FloatType>();
    // Convert the representation to vector<string>.
    SmallVector<std::string> lhs_ein =
        GetEinsumConfigAsVector(lhs_loop, get_rank(adaptor.lhs()));
    SmallVector<std::string> rhs_ein =
        GetEinsumConfigAsVector(rhs_loop, get_rank(adaptor.rhs()));
    SmallVector<std::string> out_ein =
        GetEinsumConfigAsVector(out_loop, result_ty.getRank());
    if (!CheckBatchHasEqualRank(lhs_ein.size(), lhs_loop, rhs_ein.size(),
                                rhs_loop, out_ein.size(), out_loop)) {
      return rewriter.notifyMatchFailure(
          op, "Invalid elipsis('...') within einsum config!");
    }
    // Find all unique indices in the input and output.
    llvm::SmallSetVector<StringRef, 4> input_ind;
    llvm::SmallSetVector<StringRef, 4> output_ind;
    input_ind.insert(lhs_ein.begin(), lhs_ein.end());
    input_ind.insert(rhs_ein.begin(), rhs_ein.end());
    output_ind.insert(out_ein.begin(), out_ein.end());
    llvm::SmallSetVector<StringRef, 4> reduction_axe =
        FindSummationAxes(input_ind, output_ind);
    // Find input/output values and types.
    auto loc = op.getLoc();
    // Prepare init tensor for linalg.generic op. It is zero-filled only when
    // the body accumulates into it.
    auto dyn_sizes = ExtractDynamicEinsumSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), lhs_ein, rhs_ein, out_ein);
    Value output = GetInitTensor(rewriter, loc, result_ty, dyn_sizes);
    if (!reduction_axe.empty()) {
      auto zero_attr = rewriter.getZeroAttr(result_ty.getElementType());
      Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
      output = rewriter.create<linalg::FillOp>(loc, zero, output).getResult(0);
    }
    // Create indexing maps.
    // Create a 1:1 map from f:strDimension -> affineDimension.
    int64_t nloops = input_ind.size();
    DenseMap<StringRef, AffineExpr> str_affine_dim_umap;
    for (auto& it : llvm::enumerate(input_ind)) {
      str_affine_dim_umap[it.value()] = rewriter.getAffineDimExpr(it.index());
    }
    // From einsum_config of each operand in vector<string>, generate
    // the equivalent vector<AffineExpr>.
    SmallVector<AffineMap, 4> maps;
    for (const SmallVector<std::string>& loop_operand :
         {lhs_ein, rhs_ein, out_ein}) {
      auto exprs = GetExprFromConfig(loop_operand, str_affine_dim_umap);
      maps.push_back(AffineMap::get(nloops, 0, exprs, rewriter.getContext()));
    }
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, result_ty ? result_ty : TypeRange{}, adaptor.getOperands(), output,
        maps, GetEinsumLoopsAttrs(input_ind, reduction_axe),
        [&](OpBuilder& b, Location nested_loc, ValueRange args) {
          // result = lhs * rhs, accumulated into the output element when
          // there is at least one summation axis.
          Value result_val;
          if (is_float) {
            result_val =
                b.create<mlir::arith::MulFOp>(nested_loc, args[0], args[1]);
          } else {
            result_val =
                b.create<mlir::arith::MulIOp>(nested_loc, args[0], args[1]);
          }
          if (!reduction_axe.empty()) {
            if (is_float) {
              result_val = b.create<mlir::arith::AddFOp>(nested_loc, args[2],
                                                         result_val);
            } else {
              result_val = b.create<mlir::arith::AddIOp>(nested_loc, args[2],
                                                         result_val);
            }
          }
          b.create<linalg::YieldOp>(nested_loc, result_val);
        },
        PruneAttributeList(op));
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }

 private:
  static constexpr StringRef kArrow = "->";
  static constexpr StringRef kComma = ",";
  static constexpr StringRef kEllipsis = "...";
  static bool CheckBatchHasEqualRank(size_t lhs_rank, StringRef lhs_loop,
                                     size_t rhs_rank, StringRef rhs_loop,
                                     size_t out_rank, StringRef out_loop);
  static SmallVector<std::string> GetEinsumConfigAsVector(StringRef loop,
                                                          size_t operand_rank);
};
// Namespace-scope definitions of the static constexpr StringRef members that
// are declared (with their initializers) inside EinsumToLinalgConverter.
// NOTE(review): presumably kept so the members can be ODR-used when building
// as pre-C++17 — confirm before removing.
constexpr StringRef EinsumToLinalgConverter::kArrow;
constexpr StringRef EinsumToLinalgConverter::kComma;
constexpr StringRef EinsumToLinalgConverter::kEllipsis;
// Convert the representation from string/vector<char> to vector<string>.
// i.e ("abc") -> {"a", "b", "c"}. For cases with ellipsis with batch rank 3:
// get loop_dim = f("ab...cde") = {"a","b","0","1","2","c","d","e"}
//
// `operand_rank` is only used when `loop` contains an ellipsis: the ellipsis
// expands to `operand_rank - (loop.size() - 3)` batch dimensions named
// "0", "1", .... If the operand rank is smaller than the number of explicitly
// named dimensions, the ellipsis expands to nothing.
SmallVector<std::string> EinsumToLinalgConverter::GetEinsumConfigAsVector(
    StringRef loop, size_t operand_rank) {
  SmallVector<std::string> loop_dim;
  size_t pre_elip = loop.find(kEllipsis);
  const bool has_elip = pre_elip != StringRef::npos;
  if (!has_elip) pre_elip = loop.size();
  // Add the dimension until the end or up to ellipsis if it exist.
  for (size_t pre_elip_ind = 0; pre_elip_ind < pre_elip; ++pre_elip_ind) {
    loop_dim.push_back(loop.substr(pre_elip_ind, 1).str());
  }
  if (!has_elip) return loop_dim;
  // Case where Ellipsis presence:
  const size_t non_batch_rank = loop.size() - kEllipsis.size();
  // Clamp instead of subtracting blindly: these are unsigned, so a too-small
  // operand rank would otherwise wrap around to a huge batch rank.
  const size_t batch_rank =
      operand_rank > non_batch_rank ? operand_rank - non_batch_rank : 0;
  // Add the batch dimension ("0",...,"N") where N is rank of batch into the
  // loop.
  for (size_t batch_ind = 0; batch_ind < batch_rank; ++batch_ind) {
    loop_dim.push_back(std::to_string(batch_ind));
  }
  // Add the dimension after ellipsis into the loop.
  const size_t post_elip = pre_elip + kEllipsis.size();
  for (size_t post_elip_ind = post_elip; post_elip_ind < loop.size();
       ++post_elip_ind) {
    loop_dim.push_back(loop.substr(post_elip_ind, 1).str());
  }
  return loop_dim;
}
// Returns true if all operand's batch has same rank.
bool EinsumToLinalgConverter::CheckBatchHasEqualRank(
size_t lhs_rank, StringRef lhs_loop, size_t rhs_rank, StringRef rhs_loop,
size_t out_rank, StringRef out_loop) {
SmallVector<int, 3> batch_rank_vec;
if (lhs_rank != lhs_loop.size()) {
size_t lhs_batch_rank = lhs_rank - (lhs_loop.size() - kEllipsis.size());
batch_rank_vec.push_back(lhs_batch_rank);
}
if (rhs_rank != rhs_loop.size()) {
size_t rhs_batch_rank = rhs_rank - (rhs_loop.size() - kEllipsis.size());
batch_rank_vec.push_back(rhs_batch_rank);
}
if (out_rank != out_loop.size()) {
size_t out_batch_rank = out_rank - (out_loop.size() - kEllipsis.size());
batch_rank_vec.push_back(out_batch_rank);
}
bool batch_has_equal_rank = true;
// Condition is valid if only 1 operand or less have batches.
if (batch_rank_vec.size() < 2) return batch_has_equal_rank;
if (!std::equal(batch_rank_vec.begin() + 1, batch_rank_vec.end(),
batch_rank_vec.begin()) &&
batch_rank_vec.size() > 1)
batch_has_equal_rank = false;
return batch_has_equal_rank;
}
/// Converts a pointwise op `OpTy` into a linalg.generic with only parallel
/// loops. Every operand must either be rank-0 (it is then read through a map
/// with no results) or have the common rank of the non-scalar operands (it is
/// then read through the identity map). The scalar computation in the body is
/// produced by `mhlo::MhloOpToStdScalarOp::map<OpTy>`; the pattern fails if
/// that mapping yields a null value.
template <typename OpTy>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      OpTy op, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    // Find maximum rank / number of loops.
    auto get_rank = [](Value v) {
      return v.getType().cast<ShapedType>().getRank();
    };
    auto is_scalar = [&](Value v) { return get_rank(v) == 0; };
    auto it = llvm::find_if_not(adaptor.getOperands(), is_scalar);
    Value max_rank_arg =
        it != adaptor.getOperands().end() ? *it : adaptor.getOperands().front();
    int64_t nloops = get_rank(max_rank_arg);
    // Apply only if all operands are scalar or have the same rank. Some ops,
    // like `mhlo.select`, support implicit broadcasting of scalars.
    if (!llvm::all_of(adaptor.getOperands(), [&](Value v) {
          int64_t r = get_rank(v);
          return r == 0 || r == nloops;
        })) {
      return rewriter.notifyMatchFailure(
          op, "Operands must be os same rank or scalar.");
    }
    // Find result type, if on tensors. Hold it as a plain (nullable)
    // ShapedType: wrapping a null ShapedType in an Optional would make the
    // Optional engaged and defeat the null check below. The null-tolerant
    // cast also covers a failed type conversion.
    ShapedType result_ty =
        this->typeConverter->convertType(op->getResultTypes().front())
            .template dyn_cast_or_null<ShapedType>();
    // Check result type compatibility.
    if (!result_ty || !result_ty.hasRank() || result_ty.getRank() != nloops ||
        !(result_ty.getElementType().isSignlessIntOrFloat() ||
          result_ty.getElementType().isa<ComplexType>())) {
      return rewriter.notifyMatchFailure(
          op, "mismatched operand/result types or iterator count");
    }
    // Find input/output values and types.
    auto loc = op.getLoc();
    ValueRange inputs = adaptor.getOperands();
    Value output =
        GetInitTensorFor(rewriter, loc, result_ty, op, adaptor.getOperands());
    // Create indexing maps.
    AffineMap scalar_map = AffineMap::get(nloops, 0, rewriter.getContext());
    AffineMap id_map = rewriter.getMultiDimIdentityMap(nloops);
    SmallVector<AffineMap, 4> maps;
    for (Value v : inputs) maps.push_back(is_scalar(v) ? scalar_map : id_map);
    maps.push_back(id_map);
    // Build `linalg.generic` op.
    bool failed = false;
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/result_ty, inputs, output, maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location /*nested_loc*/,
            ValueRange args) {
          Type inner_result_ty = getElementTypeOrSelf(output);
          Value inner_result = mhlo::MhloOpToStdScalarOp::map<OpTy>(
              op, inner_result_ty,
              llvm::to_vector<2>(args.take_front(inputs.size())), &rewriter);
          if (inner_result == nullptr) {
            // Remember the failure; it is reported once the builder returns.
            failed = true;
          } else {
            nested_builder.create<linalg::YieldOp>(loc, inner_result);
          }
        },
        PruneAttributeList(op));
    if (failed) return failure();
    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};
/// Rewrites a binary pointwise op on rank-0 operands as two memref.load ops,
/// the scalar computation from `mhlo::MhloOpToStdScalarOp::map<MhloOp>`, and a
/// memref.store into `out()`, then erases the original op. Fails unless the
/// first operand is a rank-0 shaped type with a signless int or float element
/// type.
///
/// NOTE(review): this two-argument `matchAndRewrite` does not look like the
/// (op, adaptor, rewriter) hook used by the other patterns in this file —
/// confirm whether this template is still instantiated anywhere.
template <typename MhloOp>
class ScalarPointwiseToStandardConverter : public OpConversionPattern<MhloOp> {
 public:
  using OpConversionPattern<MhloOp>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      MhloOp mhlo_op, ConversionPatternRewriter& rewriter) const final {
    auto operand_ty =
        mhlo_op.getOperand(0).getType().template dyn_cast<ShapedType>();
    const bool is_supported_scalar =
        operand_ty && operand_ty.getElementType().isSignlessIntOrFloat() &&
        operand_ty.getRank() == 0;
    if (!is_supported_scalar) return failure();

    auto loc = mhlo_op.getLoc();
    // Load both scalar inputs.
    auto lhs = rewriter.create<memref::LoadOp>(loc, mhlo_op.lhs());
    auto rhs = rewriter.create<memref::LoadOp>(loc, mhlo_op.rhs());
    // Compute on the loaded scalars and store the result.
    Value scalar_result = mhlo::MhloOpToStdScalarOp::map<MhloOp>(
        mhlo_op, operand_ty.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
        &rewriter);
    rewriter.create<memref::StoreOp>(loc, scalar_result, mhlo_op.out());
    rewriter.eraseOp(mhlo_op);
    return success();
  }
};
/// CRTP base for lowering HLO ops with one operand and one result that only
/// move data from the input to the output (transpose, some reshapes, ...).
/// `Derived` supplies a static `getIndexingMaps(op, builder)` returning the
/// input and output indexing maps; an empty result makes the pattern fail.
/// The generated linalg.generic has only parallel loops and a body that
/// yields its first block argument.
template <typename Derived, typename OpTy>
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;
  LogicalResult matchAndRewrite(
      OpTy op, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics(op)) return failure();

    ShapedType result_type =
        this->typeConverter->convertType(GetHloOpResultType(op))
            .template cast<ShapedType>();

    SmallVector<AffineMap, 2> indexing_maps =
        Derived::getIndexingMaps(op, &rewriter);
    if (indexing_maps.empty()) return failure();

    auto loc = op.getLoc();
    auto nloops = result_type.getRank();
    Value input = adaptor.getOperands().front();
    Value init = GetInitTensorFor(rewriter, loc, result_type, op,
                                  adaptor.getOperands());
    auto copy_body = [&](OpBuilder& nested_builder, Location /*nested_loc*/,
                         ValueRange args) {
      nested_builder.create<linalg::YieldOp>(loc, args.front());
    };
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/result_type,
        /*inputs=*/input,
        /*outputBuffers=*/ValueRange{init}, indexing_maps,
        GetNParallelLoopsAttrs(nloops), copy_body, PruneAttributeList(op));
    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};
/// Pattern converting a broadcast op (`OpTy`) to a linalg.generic through
/// DataMovementOpConverter.
template <typename OpTy>
class BroadcastConverter
    : public DataMovementOpConverter<BroadcastConverter<OpTy>, OpTy> {
 public:
  using DataMovementOpConverter<BroadcastConverter,
                                OpTy>::DataMovementOpConverter;

  /// Returns {input map, output identity map}. The input map selects the
  /// trailing `input_rank` loop dimensions, skipping the dimensions that come
  /// from `broadcast_sizes`; a rank-0 input gets a map with no results.
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcast_op,
                                                   Builder* b) {
    const unsigned input_rank = broadcast_op.operand()
                                    .getType()
                                    .template cast<ShapedType>()
                                    .getRank();
    const unsigned nloops = GetHloOpResultType(broadcast_op).getRank();
    // BroadcastOp prepends the dimensions in the `broadcast_sizes` attribute to
    // the input's dimensions.
    const unsigned num_prepended_dims =
        llvm::size(broadcast_op.broadcast_sizes());

    SmallVector<AffineExpr, 4> input_dim_exprs;
    input_dim_exprs.reserve(input_rank);
    for (unsigned i = 0; i != input_rank; ++i) {
      input_dim_exprs.push_back(b->getAffineDimExpr(num_prepended_dims + i));
    }

    MLIRContext* context = b->getContext();
    // An empty expression list means the input is a scalar, i.e. this is a
    // scalar broadcast op.
    AffineMap input_map =
        input_dim_exprs.empty()
            ? AffineMap::get(nloops, /*symbolCount=*/0, context)
            : AffineMap::get(nloops, /*symbolCount=*/0, input_dim_exprs,
                             context);
    return {input_map, b->getMultiDimIdentityMap(nloops)};
  }
};
/// Pattern converting mhlo.broadcast_in_dim to a linalg.generic through
/// DataMovementOpConverter.
class HloBroadcastInDimConverter
    : public DataMovementOpConverter<HloBroadcastInDimConverter,
                                     mhlo::BroadcastInDimOp> {
 public:
  using DataMovementOpConverter<
      HloBroadcastInDimConverter,
      mhlo::BroadcastInDimOp>::DataMovementOpConverter;

  /// Returns {operand map, output identity map}. For each operand dimension
  /// the operand map uses the constant 0 when that dimension has size 1 and
  /// the result dimension it maps to does not, and the mapped result
  /// dimension otherwise. A rank-0 operand gets a map with no results.
  static SmallVector<AffineMap, 2> getIndexingMaps(
      mhlo::BroadcastInDimOp broadcast_op, Builder* b) {
    auto result_type = GetHloOpResultType(broadcast_op);
    auto operand_type = broadcast_op.operand().getType().cast<ShapedType>();
    const unsigned nloops = result_type.getRank();
    AffineMap output_map = b->getMultiDimIdentityMap(nloops);

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operand_type.getRank() == 0) {
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              output_map};
    }

    auto operand_shape = operand_type.getShape();
    SmallVector<AffineExpr, 4> dim_exprs;
    dim_exprs.reserve(nloops);
    if (auto bcast_dims = broadcast_op.broadcast_dimensions()) {
      auto result_shape = result_type.getShape();
      int64_t operand_dim = 0;
      for (const APInt& mapped : bcast_dims.getValues<APInt>()) {
        const int result_dim = mapped.getSExtValue();
        const bool expansion_needed = operand_shape[operand_dim] == 1 &&
                                      result_shape[result_dim] != 1;
        if (expansion_needed) {
          dim_exprs.push_back(b->getAffineConstantExpr(0));
        } else {
          dim_exprs.push_back(b->getAffineDimExpr(result_dim));
        }
        ++operand_dim;
      }
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dim_exprs, b->getContext()),
        output_map};
  }
};
// If the input has a static shape we know exactly when the broadcast must
// expand (the dimension is 1, which also trivially expands to 1) or will never
// expand (the dimension is not 1). We can also source the information from the
// optionally provided attributes on statically known broadcasting behavior.
// This means we can lower the broadcast just as we would lower a fully static
// broadcast and go directly to `linalg.generic`.
// This also covers the important case of broadcasting a scalar. Ideally the
// pattern (`mhlo.constant` -> `mhlo.dynamic_broadcast_in_dim`) should be
// converted to a tensor dialect op similar to TF's `ConstantLikeOp`.
// Lowers `mhlo.dynamic_broadcast_in_dim` straight to `linalg.generic` when the
// expansion behavior of every operand dimension can be decided, either from
// the static operand shape or from the op's `known_expanding_dimensions` /
// `known_nonexpanding_dimensions` attributes. Fails to match otherwise.
class HloDynamicBroadcastInDimConverter
    : public OpConversionPattern<mhlo::DynamicBroadcastInDimOp> {
 public:
  using OpConversionPattern<mhlo::DynamicBroadcastInDimOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Value input = adaptor.operand();
    auto input_type = input.getType().dyn_cast<RankedTensorType>();
    if (!input_type) return failure();
    auto output_type =
        typeConverter->convertType(op.getType()).dyn_cast<RankedTensorType>();
    if (!output_type) return failure();

    // One expression per operand dimension; a null entry means "expansion
    // behavior not yet known".
    SmallVector<AffineExpr> input_exprs(input_type.getRank(), nullptr);

    // Result dimension that each operand dimension is broadcast into.
    auto target_dims = llvm::to_vector(
        llvm::map_range(op.broadcast_dimensions(), [](const APInt& d) {
          return static_cast<int64_t>(d.getLimitedValue());
        }));

    // Expanding dimensions always read index 0; non-expanding dimensions
    // follow the corresponding result dimension.
    auto expanding_expr = [&]() { return rewriter.getAffineConstantExpr(0); };
    auto identity_expr = [&](uint64_t operand_dim) {
      return rewriter.getAffineDimExpr(target_dims[operand_dim]);
    };

    // First source of information: statically known operand extents.
    ArrayRef<int64_t> input_shape = input_type.getShape();
    for (size_t i = 0, e = input_shape.size(); i != e; ++i) {
      const int64_t extent = input_shape[i];
      if (ShapedType::isDynamic(extent)) continue;
      if (extent == 1) {
        input_exprs[i] = expanding_expr();
      } else {
        input_exprs[i] = identity_expr(i);
      }
    }

    // Second source: explicit annotations on the op, which override the above.
    if (op.known_expanding_dimensions()) {
      for (const APInt& d :
           op.known_expanding_dimensions()->getValues<APInt>()) {
        input_exprs[d.getLimitedValue()] = expanding_expr();
      }
    }
    if (op.known_nonexpanding_dimensions()) {
      for (const APInt& d :
           op.known_nonexpanding_dimensions()->getValues<APInt>()) {
        const auto operand_dim = d.getLimitedValue();
        input_exprs[operand_dim] = identity_expr(operand_dim);
      }
    }

    // Give up if any dimension is still undecided.
    if (llvm::any_of(input_exprs, [](AffineExpr expr) { return !expr; }))
      return failure();

    // Emit the `linalg.generic` that copies the operand into the init tensor.
    Location loc = op.getLoc();
    int64_t nloops = output_type.getRank();
    Value init =
        GetInitTensorFor(rewriter, loc, output_type, op, adaptor.getOperands());
    rewriter.replaceOpWithNewOp<linalg::GenericOp>(
        op, TypeRange{init.getType()}, ValueRange{input},
        /*outputBuffers=*/ValueRange{init},
        llvm::makeArrayRef(
            {AffineMap::get(/*dimCount=*/nloops, /*symbolCount=*/0,
                            input_exprs, rewriter.getContext()),
             rewriter.getMultiDimIdentityMap(nloops)}),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location /*nested_loc*/,
            ValueRange args) {
          nested_builder.create<linalg::YieldOp>(loc, *args.begin());
        },
        PruneAttributeList(op));
    return success();
  }
};
// Data-movement converter for transpose-like ops. It supplies the indexing
// maps consumed by the `DataMovementOpConverter` base.
template <typename OpTy>
class TransposeConverter
    : public DataMovementOpConverter<TransposeConverter<OpTy>, OpTy> {
 public:
  using DataMovementOpConverter<TransposeConverter<OpTy>,
                                OpTy>::DataMovementOpConverter;

  // Returns {operand map, identity result map}. For every position `i` of the
  // `permutation` attribute with value `p`, operand dimension `p` is indexed
  // by loop dimension `i`.
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto result_type = GetHloOpResultType(op).template cast<ShapedType>();
    auto nloops = result_type.getRank();

    // Start with null expressions; each one is filled in by the loop below.
    SmallVector<AffineExpr, 2> operand_exprs;
    operand_exprs.resize(nloops);

    unsigned loop_dim = 0;
    for (const APInt& operand_dim : op.permutation()) {
      operand_exprs[operand_dim.getZExtValue()] = b->getAffineDimExpr(loop_dim);
      ++loop_dim;
    }

    AffineMap operand_map = AffineMap::get(nloops, /*symbolCount=*/0,
                                           operand_exprs, b->getContext());
    return {operand_map, b->getMultiDimIdentityMap(nloops)};
  }
};
/// Lowers mhlo.RealDynamicSliceOp to tensor.extract_slice and other
/// arith/tensor dialect ops.
///
/// For every operand dimension the pattern extracts start/limit/stride from
/// the op's 1-D index tensors, derives the slice size (from the static result
/// type when available, otherwise via `computeSize`), clamps the start offset
/// into `[0, operand_dim - size]`, and finally emits a single
/// `tensor.extract_slice`.
class RealDynamicSliceConverter
: public OpConversionPattern<mhlo::RealDynamicSliceOp> {
public:
using OpConversionPattern<mhlo::RealDynamicSliceOp>::OpConversionPattern;
/// Computes size of a slice as :-
/// size = ceil((limit - start)/stride)
/// The division is emitted as an unsigned ceil-division (`arith.ceildivui`).
static Value computeSize(Location loc, Value start, Value limit, Value stride,
ConversionPatternRewriter& rewriter) {
Value delta = rewriter.create<arith::SubIOp>(loc, limit, start);
return rewriter.create<arith::CeilDivUIOp>(loc, delta, stride);
}
/// Rewrites `real_dynamic_slice_op` into `tensor.extract_slice`. Fails to
/// match when the (converted) operand is not a ranked shaped type.
LogicalResult matchAndRewrite(
mhlo::RealDynamicSliceOp real_dynamic_slice_op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
Location loc = real_dynamic_slice_op.getLoc();
auto arg_type = adaptor.operand().getType().dyn_cast<ShapedType>();
if (!arg_type || !arg_type.hasRank()) {
return rewriter.notifyMatchFailure(real_dynamic_slice_op,
"require known-rank args");
}
auto result_type =
this->typeConverter->convertType(real_dynamic_slice_op.getType())
.cast<RankedTensorType>();
// i64 zero used as the lower bound of the clamp below.
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getI64IntegerAttr(0));
Type i64_type = rewriter.getI64Type();
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
// Argument types (min, operand, max) handed to the scalar clamp mapping.
SmallVector<Type, 3> clamp_type(3, i64_type);
for (auto i : llvm::seq<unsigned>(0, arg_type.getRank())) {
Value dim = rewriter.create<arith::ConstantIndexOp>(loc, i);
Value start =
rewriter.create<tensor::ExtractOp>(loc, adaptor.start_indices(), dim);
Value limit =
rewriter.create<tensor::ExtractOp>(loc, adaptor.limit_indices(), dim);
Value stride =
rewriter.create<tensor::ExtractOp>(loc, adaptor.strides(), dim);
// Compute i-th dimension size of the result : size[i].
// If the i-th dimension of the result type is known, we go ahead with it
// else we compute it using limit, start and stride values.
// NOTE(review): the dynamically computed size has the element type of the
// index tensors, while the static one is `index`-typed and both are later
// subtracted from an `index`-typed dim. This appears to assume the index
// tensors have `index` element type -- confirm against callers.
int64_t result_dim_size = result_type.getDimSize(i);
Value size;
if (ShapedType::isDynamic(result_dim_size)) {
size = computeSize(loc, start, limit, stride, rewriter);
} else {
size = rewriter.create<arith::ConstantIndexOp>(loc, result_dim_size);
}
// Fetch i-th dimension size of the operand and calculate upper bound as
// :-
// ub = operand_dim[i] - size[i]
Value operand_dim_size =
rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(), dim);
Value upper_bound =
rewriter.createOrFold<arith::SubIOp>(loc, operand_dim_size, size);
// We clamp the start_index to keep it bounded as :-
// start index : 0 <= start_index[i] <= ub
// ClampOp lowering does not support index type, so we cast it into
// integer type.
start = rewriter.create<arith::IndexCastOp>(loc, i64_type, start);
upper_bound =
rewriter.create<arith::IndexCastOp>(loc, i64_type, upper_bound);
start = mhlo::MhloOpToStdScalarOp::map<mhlo::ClampOp>(
loc, i64_type, clamp_type, ValueRange{zero, start, upper_bound},
&rewriter);
// The clamped offset is cast back to `index` for tensor.extract_slice.
offsets.push_back(
rewriter
.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), start)
.getResult());
// Static sizes are recorded as attributes, dynamic ones as SSA values.
if (ShapedType::isDynamic(result_dim_size)) {
sizes.push_back(size);
} else {
sizes.push_back(rewriter.getI64IntegerAttr(result_dim_size));
}
strides.push_back(stride);
}
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
real_dynamic_slice_op, result_type, adaptor.operand(), offsets, sizes,
strides);
return success();
}
};
// Converts `mhlo.reshape` with a statically shaped result to tensor reshape
// ops. Three strategies are tried in order:
//   1. dynamic operand -> rank-0 result: cast every operand dim to 1, then
//      collapse to a scalar tensor;
//   2. a single `tensor.collapse_shape` or `tensor.expand_shape` when a
//      reassociation between operand and result shapes can be computed;
//   3. otherwise collapse the operand to 1-D, cast to the static element
//      count, and expand to the result shape.
class ReshapeOpConverter : public OpConversionPattern<mhlo::ReshapeOp> {
public:
using OpConversionPattern::OpConversionPattern;
// Fails to match when the op does not have buffer-or-tensor semantics or when
// the result shape is not fully static.
LogicalResult matchAndRewrite(
mhlo::ReshapeOp reshape_op, mhlo::ReshapeOp::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const final {
if (!VerifyHloOpBufferOrTensorSemantics(reshape_op)) return failure();
auto operand = adaptor.operand();
auto operand_type = operand.getType().cast<ShapedType>();
auto elem_type = operand_type.getElementType();
auto result_type = reshape_op.getType().cast<ShapedType>();
if (!result_type.hasStaticShape()) return failure();
result_type = typeConverter->convertType(result_type).cast<ShapedType>();
// Special case where the result is a scalar.
if (result_type.getRank() == 0 && !operand_type.hasStaticShape()) {
// This means all dimensions of the operand need to be 1. We add a cast to
// cast the dynamic dimensions to 1.
auto static_type = RankedTensorType::get(
llvm::SmallVector<int64_t>(operand_type.getRank(), 1), elem_type);
operand = rewriter.create<tensor::CastOp>(reshape_op.getLoc(),
static_type, operand);
// An empty reassociation collapses all (unit) dimensions away.
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
reshape_op, result_type, operand, ArrayRef<ReassociationIndices>{});
return success();
}
// Compute the reassociation maps for the linalg operation. This will
// succeed if the reshape can be done with a single expand_shape or
// collapse_shape.
if (Optional<SmallVector<ReassociationIndices>> reassociation_map =
getReassociationIndicesForReshape(operand_type, result_type)) {
if (result_type.getRank() < operand_type.getRank()) {
// We have found a working reassociation map. If the operand is dynamic,
// we first need to cast all unknown dimensions in the input that get
// collapsed to a static-sized dimension in the output, to 1.
SmallVector<int64_t> shape(operand_type.getShape().begin(),
operand_type.getShape().end());
for (const auto& map : llvm::enumerate(*reassociation_map)) {
// If the result dim is dynamic, we do not mind dynamic entries in the
// source.
if (result_type.isDynamicDim(map.index())) continue;
for (auto target_dim : map.value()) {
if (shape[target_dim] == ShapedType::kDynamicSize)
shape[target_dim] = 1;
}
}
// Only insert the cast when it actually refines the operand type.
auto new_operand_type = RankedTensorType::get(shape, elem_type);
if (new_operand_type != operand_type) {
operand = rewriter.create<tensor::CastOp>(reshape_op.getLoc(),
new_operand_type, operand);
}
rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
reshape_op, result_type, operand, *reassociation_map);
} else {
// Result rank >= operand rank: a single expand_shape suffices.
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
reshape_op, result_type, operand, *reassociation_map);
}
return success();
}
Value collapsed_op = operand;
Location loc = reshape_op.getLoc();
// Builds the expression list [d0, d1, ..., d(n-1)], i.e. one reassociation
// group that spans all `n` dimensions.
auto get_identity_exprs = [&rewriter](int64_t n) {
SmallVector<AffineExpr, 4> exprs;
for (int i = 0; i < n; ++i) exprs.push_back(rewriter.getAffineDimExpr(i));
return exprs;
};
// Otherwise, we need to first reduce all source dimensions into one and
// then expand to the destination dimensions. If there is only a single
// source dimension, the reduce step can be skipped. TensorCollapseShape
// expects a different rank of operand and result.
if (operand_type.getRank() != 1) {
SmallVector<ReassociationExprs, 4> collapsing_map = {
// Use operand_type here because we need to collapse all operands
// dimensions.
get_identity_exprs(operand_type.getRank())};
collapsed_op = rewriter.create<tensor::CollapseShapeOp>(loc, operand,
collapsing_map);
}
// Cast to a known static type if the input has dynamic dimensions.
// The element count comes from the (static) result type.
int64_t total_elems = result_type.getNumElements();
auto collapsed_type = RankedTensorType::get({total_elems}, elem_type);
collapsed_op =
rewriter.create<tensor::CastOp>(loc, collapsed_type, collapsed_op);
if (result_type.getRank() == 1) {
// The 1-D intermediate already is the result.
rewriter.replaceOp(reshape_op, collapsed_op);
} else {
SmallVector<ReassociationExprs, 4> expanding_map = {
// Use result_type here because we need to expand to all result
// dimensions.
get_identity_exprs(result_type.getRank())};
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
reshape_op, result_type, collapsed_op, expanding_map);
}
return success();
}
};
// Lowers an iota-like op to a `linalg.generic` with no inputs whose body
// yields the loop index along `iota_dimension`, converted to the result
// element type (index -> integer, and additionally integer -> float for
// floating-point results).
//
// Fails to match when the result type is not a shaped type (before or after
// type conversion) or when its element type is not a signless integer or a
// float.
template <typename OpTy>
class IotaConverter : public OpConversionPattern<OpTy> {
 public:
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy iota_op, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    ShapedType result_shaped_type = GetHloOpResultType(iota_op);
    if (!result_shaped_type) return failure();

    // Type conversion may fail (null type) or yield a non-shaped type. Both
    // cases must be rejected here, because the type is dereferenced below.
    result_shaped_type = this->typeConverter->convertType(result_shaped_type)
                             .template dyn_cast_or_null<ShapedType>();
    if (!result_shaped_type) return failure();

    auto result_element_type = result_shaped_type.getElementType();
    if (!result_element_type.isSignlessIntOrFloat()) return failure();

    // Construct the indexing maps needed for linalg.generic ops.
    unsigned nloops = result_shaped_type.getRank();
    Location loc = iota_op.getLoc();
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/
        ArrayRef<Type>{result_shaped_type},
        /*inputs=*/ValueRange{},
        /*outputBuffers=*/
        ValueRange{GetInitTensorFor(rewriter, loc, result_shaped_type, iota_op,
                                    adaptor.getOperands())},
        llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& nested_builder, Location nested_loc,
            ValueRange /*args*/) {
          // Loop index along the iota dimension.
          Value index_op = nested_builder.create<linalg::IndexOp>(
              nested_loc, iota_op.iota_dimension());
          // Cast to an integer of the same bit width as the result element.
          Value cast_op = nested_builder.create<arith::IndexCastOp>(
              nested_loc,
              nested_builder.getIntegerType(
                  result_element_type.getIntOrFloatBitWidth()),
              index_op);
          // Floating-point results need a further int -> float conversion.
          if (result_element_type.template isa<FloatType>()) {
            cast_op = nested_builder.create<arith::SIToFPOp>(
                nested_loc, result_element_type, cast_op);
          }
          nested_builder.create<linalg::YieldOp>(nested_loc, cast_op);
        },
        PruneAttributeList(iota_op));
    rewriter.replaceOp(iota_op, linalg_op.result_tensors());
    return success();
  }
};
/// Converts mhlo.concatenate operation to a linalg.generic op.
///
/// The generic op has no inputs; its body selects, per output element, which
/// operand to read from by building a chain of nested `scf.if` ops that
/// compare the index along the concatenation dimension against the running
/// sum of operand sizes, and then extracts the element with `tensor.extract`.
struct ConcatenateConverter : public OpConversionPattern<mhlo::ConcatenateOp> {
using OpConversionPattern<mhlo::ConcatenateOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ConcatenateOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
// Shortcut the one-operand case, simplifies code below.
if (adaptor.getOperands().size() == 1) {
rewriter.replaceOp(op, adaptor.getOperands()[0]);
return success();
}
auto result_type =
this->typeConverter->convertType(op.getResult().getType())
.dyn_cast<RankedTensorType>();
if (!result_type) return failure();
uint64_t dim = op.dimension();
Location loc = op.getLoc();
// Created outside the generic op; used as the initial running offset.
Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
// Allocate the output tensor with init_tensor.
Value result =
GetInitTensorFor(rewriter, loc, result_type, op, adaptor.getOperands());
// Generate a generic op to gather the elements of the concatenate. This is
// awkward standalone but allows fusion with other generic ops.
int64_t nloops = result_type.getRank();
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op,
/*resultTensorTypes=*/result_type,
/*inputs=*/ValueRange{}, /*outputBuffers=*/result,
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
// Note: `loc` and `result` below shadow the outer variables of the same
// names.
[&](OpBuilder& nested_builder, Location loc, ValueRange) {
// `b` is re-pointed into the then/else regions of the scf.if chain as it
// is built; `nested_builder` stays at the generic op's body.
OpBuilder b = nested_builder;
// Sum of the sizes (along `dim`) of the operands handled so far.
Value concat_dim_size = zero;
// Result of the outermost scf.if; yielded by the generic op.
Value result;
SmallVector<Value, 4> extract_indices;
extract_indices.reserve(nloops);
for (int64_t i = 0; i < nloops; i++) {
extract_indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
Value index_op = b.create<linalg::IndexOp>(loc, dim);
for (auto& it : llvm::enumerate(adaptor.getOperands())) {
Value arg = it.value();
Value new_concat_dim_size;
scf::IfOp if_op;
// Every operand except the last gets its own scf.if; the last operand
// is handled in the else-branch of the previous one.
if (it.index() != (adaptor.getOperands().size() - 1)) {
// Calculate how far along we have iterated along the concatenate
// dimension. That way we can tell which input to select.
new_concat_dim_size = b.create<arith::AddIOp>(
loc, concat_dim_size, b.create<tensor::DimOp>(loc, arg, dim));
Value cmp = b.create<arith::CmpIOp>(
loc, rewriter.getI1Type(), arith::CmpIPredicate::ult,
index_op, new_concat_dim_size);
// `true` requests an else region for the scf.if.
if_op = b.create<scf::IfOp>(loc, result_type.getElementType(),
cmp, true);
if (result) {
// Nested scf.if: forward its result to the enclosing else-branch.
b.create<scf::YieldOp>(loc, if_op->getResults()[0]);
} else {
// Outermost scf.if: its result is the value of the whole chain.
result = if_op->getResults()[0];
}
b = if_op.getThenBodyBuilder(b.getListener());
}
// Now adjust the index for the concatenated dimension to fit into
// the selected tensor and do an extract at that position.
extract_indices[dim] =
b.create<arith::SubIOp>(loc, index_op, concat_dim_size);
Value extract =
b.create<tensor::ExtractOp>(loc, arg, extract_indices);
b.create<scf::YieldOp>(loc, extract);
if (if_op) {
// Continue building in the else-branch with the updated offset.
b = if_op.getElseBodyBuilder(b.getListener());
concat_dim_size = new_concat_dim_size;
}
}
nested_builder.create<linalg::YieldOp>(loc, result);
},
PruneAttributeList(op));
return success();
}
};
// Replaces `mhlo.constant` with `arith.constant` of the converted tensor type.
class ConstConverterTensor : public OpConversionPattern<mhlo::ConstOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConstOp const_op, OpAdaptor /*adaptor*/,
      ConversionPatternRewriter& rewriter) const final {
    auto elements = const_op.value().cast<DenseElementsAttr>();
    auto converted_type =
        typeConverter->convertType(const_op.getType()).cast<ShapedType>();

    // When type conversion changed the type (signedness conversion), rebuild
    // the attribute with the new element type while keeping every value's
    // bits unchanged.
    const bool type_changed = converted_type != const_op.getType();
    if (type_changed) {
      auto keep_bits = [](const APInt& i) { return i; };
      elements =
          elements.mapValues(converted_type.getElementType(), keep_bits);
    }

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(const_op, converted_type,
                                                   elements);
    return success();
  }
};
// Data-movement converter for `mhlo.reverse`. It supplies the indexing maps
// consumed by the `DataMovementOpConverter` base.
// TODO(b/156787842): Support the lowering for dynamic shapes.
class ReverseConverter
    : public DataMovementOpConverter<ReverseConverter, mhlo::ReverseOp> {
 public:
  using DataMovementOpConverter<ReverseConverter,
                                mhlo::ReverseOp>::DataMovementOpConverter;

  // Returns {operand map, identity result map}. Every reversed dimension `i`
  // of static size `n` is read at `(n - 1) - d_i`; all other dimensions are
  // read at `d_i`. Returns an empty vector when a reversed dimension is
  // dynamic.
  static SmallVector<AffineMap, 2> getIndexingMaps(mhlo::ReverseOp op,
                                                   Builder* b) {
    auto result_type = GetHloOpResultType(op).cast<ShapedType>();
    auto nloops = result_type.getRank();
    SmallVector<AffineExpr, 2> input_exprs;
    input_exprs.reserve(nloops);
    // 64-bit counters/sizes throughout: ranks and dimension sizes are
    // `int64_t`, so narrowing them to `int` could truncate large extents.
    for (int64_t i = 0; i < nloops; ++i)
      input_exprs.push_back(b->getAffineDimExpr(i));
    for (auto dim : op.dimensions()) {
      const int64_t i = dim.getZExtValue();
      if (result_type.isDynamicDim(i)) return {};
      const int64_t n = result_type.getShape()[i];
      input_exprs[i] = b->getAffineConstantExpr(n - 1) - input_exprs[i];
    }
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, input_exprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};
// Lowers static `mhlo.slice` to `tensor.extract_slice` with constant offsets,
// sizes and strides taken from the op's `start_indices`, `limit_indices` and
// `strides` attributes.
//
// Fails to match when the operand is not a ranked shaped type or when a
// stride is not positive.
class SliceConverter : public OpConversionPattern<mhlo::SliceOp> {
 public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::SliceOp slice_op, typename mhlo::SliceOp::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto arg_type = adaptor.getOperands()[0].getType().dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.hasRank()) {
      return rewriter.notifyMatchFailure(slice_op, "expects known-rank args");
    }

    SmallVector<OpFoldResult, 3> offsets, sizes, strides;
    for (int64_t i = 0, e = arg_type.getRank(); i < e; ++i) {
      const int64_t start = slice_op.start_indices().getValues<int64_t>()[i];
      const int64_t limit = slice_op.limit_indices().getValues<int64_t>()[i];
      const int64_t stride = slice_op.strides().getValues<int64_t>()[i];
      // Guard the division below; a non-positive stride has no valid size.
      if (stride <= 0) {
        return rewriter.notifyMatchFailure(slice_op,
                                           "expects positive strides");
      }
      offsets.push_back(rewriter.getI64IntegerAttr(start));
      // Number of selected elements k is the largest value satisfying
      //   start + (k - 1) * stride <= limit - 1,
      // i.e. k = ceil((limit - start) / stride) for a non-empty range and 0
      // for an empty one. The previous form `(limit - 1 - start) / stride + 1`
      // yields 1 instead of 0 when `limit == start` and `stride > 1`, because
      // C++ integer division truncates towards zero.
      const int64_t extent = limit - start;
      const int64_t size = extent > 0 ? (extent + stride - 1) / stride : 0;
      sizes.push_back(rewriter.getI64IntegerAttr(size));
      strides.push_back(rewriter.getI64IntegerAttr(stride));
    }
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        slice_op, adaptor.getOperands()[0], offsets, sizes, strides);
    return success();
  }
};
// Lowers `mhlo.dynamic_slice` to `tensor.extract_slice`. Slice sizes come
// from the static `slice_sizes` attribute; each start index is read from its
// scalar tensor, clamped to `[0, operand_dim - slice_size]` and cast to
// `index`. All strides are 1.
//
// Fails to match when the operand is not a ranked shaped type or when the op
// has no start indices.
class DynamicSliceConverter : public OpConversionPattern<mhlo::DynamicSliceOp> {
 public:
  using OpConversionPattern<mhlo::DynamicSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicSliceOp dynamic_slice_op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = dynamic_slice_op.getLoc();
    auto arg_type = adaptor.operand().getType().dyn_cast<ShapedType>();
    if (!arg_type || !arg_type.hasRank()) {
      return rewriter.notifyMatchFailure(dynamic_slice_op,
                                         "require known-rank args");
    }
    // The element type of the first start index is needed below; bail out
    // instead of indexing into an empty operand range.
    if (adaptor.start_indices().empty()) {
      return rewriter.notifyMatchFailure(dynamic_slice_op,
                                         "expects at least one start index");
    }

    auto index_type = rewriter.getIndexType();
    SmallVector<OpFoldResult, 3> start_indices, sizes;
    // Zero of the start-index element type; lower bound of the clamp.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(adaptor.start_indices()[0]
                                      .getType()
                                      .cast<RankedTensorType>()
                                      .getElementType()));
    for (const auto& en : llvm::enumerate(
             llvm::zip(adaptor.start_indices(),
                       dynamic_slice_op.slice_sizes().getValues<int64_t>()))) {
      int64_t size = std::get<1>(en.value());
      sizes.push_back(rewriter.getI64IntegerAttr(size));
      // By mhlo.DynamicSlice definition:
      //   `start_indices[i] = clamp(start_indices[i],
      //       0, operand.dimension_size[i] - size_indices[i])`
      Value start_index =
          rewriter.create<tensor::ExtractOp>(loc, std::get<0>(en.value()));
      Value ub = rewriter.createOrFold<tensor::DimOp>(loc, adaptor.operand(),
                                                      en.index());
      // ClampOp lowering does not support index type, so cast it into integer
      // type.
      ub = rewriter.createOrFold<arith::IndexCastOp>(loc, start_index.getType(),
                                                     ub);
      ub = rewriter.createOrFold<arith::SubIOp>(
          loc, ub,
          rewriter.create<arith::ConstantOp>(
              loc, rewriter.getIntegerAttr(start_index.getType(), size)));
      start_index = mhlo::MhloOpToStdScalarOp::map<mhlo::ClampOp>(
          loc, start_index.getType(),
          ArrayRef<Type>{start_index.getType(), start_index.getType(),
                         start_index.getType()},
          ArrayRef<Value>{zero, start_index, ub}, &rewriter);
      // tensor.extract_slice expects `index`-typed dynamic offsets.
      start_indices.push_back(
          rewriter.create<arith::IndexCastOp>(loc, index_type, start_index)
              .getResult());
    }

    int64_t rank = arg_type.getRank();
    SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
    auto result_type =
        this->typeConverter->convertType(dynamic_slice_op.getType())
            .cast<RankedTensorType>();
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
        dynamic_slice_op, result_type, adaptor.operand(), start_indices, sizes,
        strides);
    return success();
  }
};
// Lowers `mhlo.dynamic_update_slice` to `tensor.insert_slice`. Sizes come
// from the static shape of `update`; each start index is read from its scalar
// tensor, clamped to `[0, operand_dim - update_dim]` and cast to `index`. All
// strides are 1.
//
// Fails to match when `operand` or `update` is not a statically shaped ranked
// tensor, or when the op has no start indices.
class DynamicUpdateSliceConverter
    : public OpConversionPattern<mhlo::DynamicUpdateSliceOp> {
 public:
  using OpConversionPattern<mhlo::DynamicUpdateSliceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    auto operand_type =
        adaptor.operand().getType().dyn_cast<RankedTensorType>();
    if (!operand_type || !operand_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "require static ranked type for operand");
    }
    auto update_type = adaptor.update().getType().dyn_cast<RankedTensorType>();
    if (!update_type || !update_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "require static ranked type for update");
    }
    // The element type of the first start index is needed below; bail out
    // instead of indexing into an empty operand range.
    if (adaptor.start_indices().empty()) {
      return rewriter.notifyMatchFailure(op,
                                         "expects at least one start index");
    }

    // We do not have to clamp sizes because the semantic of `update`
    // guarantees that it is always in the bounds. See
    // https://www.tensorflow.org/xla/operation_semantics#dynamicupdateslice
    SmallVector<OpFoldResult, 3> sizes;
    for (auto size : update_type.getShape()) {
      sizes.push_back(rewriter.getIndexAttr(size));
    }

    auto index_type = rewriter.getIndexType();
    SmallVector<OpFoldResult, 3> start_indices;
    Type start_index_type = adaptor.start_indices()[0]
                                .getType()
                                .cast<RankedTensorType>()
                                .getElementType();
    // Zero of the start-index element type; lower bound of the clamp.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(start_index_type));
    for (const auto& en : llvm::enumerate(adaptor.start_indices())) {
      // By mhlo.DynamicUpdateSlice definition:
      //   `start_indices[i] = clamp(start_indices[i],
      //       0, operand.dimension_size[i] - update.dimension_size[i])`
      Value start_index = rewriter.create<tensor::ExtractOp>(loc, en.value());
      // Both shapes are static, so the upper bound is a compile-time constant.
      Value ub = rewriter.create<arith::ConstantOp>(
          loc, rewriter.getIntegerAttr(start_index_type,
                                       operand_type.getDimSize(en.index()) -
                                           update_type.getDimSize(en.index())));
      start_index = mhlo::MhloOpToStdScalarOp::map<mhlo::ClampOp>(
          loc, start_index_type,
          ArrayRef<Type>{start_index_type, start_index_type, start_index_type},
          ArrayRef<Value>{zero, start_index, ub}, &rewriter);
      // tensor.insert_slice expects `index`-typed dynamic offsets.
      start_indices.push_back(
          rewriter.create<arith::IndexCastOp>(loc, index_type, start_index)
              .getResult());
    }

    int64_t rank = operand_type.getRank();
    SmallVector<OpFoldResult, 3> strides(rank, rewriter.getI64IntegerAttr(1));
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        op, adaptor.update(), adaptor.operand(), start_indices, sizes, strides);
    return success();
  }
};
// Classification of `mhlo.dot` by operand ranks, as computed by
// `GetDotOperationType`.
enum class DotOperationType {
// rank-1 lhs, rank-1 rhs.
kVectorDot = 0,
// rank-2 lhs, rank-1 rhs.
kMatrixVector,
// rank-1 lhs, rank-2 rhs.
kVectorMatrix,
// rank-2 lhs, rank-2 rhs.
kMatrixMatrix,
// Any other rank combination, or mismatching contracting dimensions.
kUnsupported
};
// Classifies `dot_op` by the ranks of its operands. Only ranks 1 and 2 are
// supported, and the contracting dimensions (last of lhs, first of rhs) must
// be equal unless at least one of them is dynamic; everything else is
// reported as `kUnsupported`.
DotOperationType GetDotOperationType(mhlo::DotOp dot_op) {
  ArrayRef<int64_t> lhs_shape =
      dot_op.lhs().getType().cast<ShapedType>().getShape();
  ArrayRef<int64_t> rhs_shape =
      dot_op.rhs().getType().cast<ShapedType>().getShape();

  const size_t lhs_rank = lhs_shape.size();
  const size_t rhs_rank = rhs_shape.size();
  const bool ranks_supported = (lhs_rank == 1 || lhs_rank == 2) &&
                               (rhs_rank == 1 || rhs_rank == 2);
  if (!ranks_supported) return DotOperationType::kUnsupported;

  // The contraction runs over the last lhs dimension and the first rhs
  // dimension. A dynamic size is compatible with anything.
  const int64_t lhs_contract = lhs_shape[lhs_rank - 1];
  const int64_t rhs_contract = rhs_shape[0];
  const bool contract_dims_match = lhs_contract == ShapedType::kDynamicSize ||
                                   rhs_contract == ShapedType::kDynamicSize ||
                                   lhs_contract == rhs_contract;
  if (!contract_dims_match) return DotOperationType::kUnsupported;

  if (lhs_rank == 1) {
    return rhs_rank == 1 ? DotOperationType::kVectorDot
                         : DotOperationType::kVectorMatrix;
  }
  return rhs_rank == 1 ? DotOperationType::kMatrixVector
                       : DotOperationType::kMatrixMatrix;
}
// Collects the `tensor.dim` values for the dynamic dimensions of a dot
// product's result: dimension 0 of `lhs` for matrix-matrix and matrix-vector
// products, followed by dimension 1 of `rhs` for matrix-matrix and
// vector-matrix products. Nothing is collected for other kinds.
SmallVector<Value, 2> GetDotOpInitTensorDynSizes(OpBuilder& b, Location loc,
                                                 Value lhs, Value rhs,
                                                 DotOperationType type) {
  const bool lhs_is_matrix = type == DotOperationType::kMatrixMatrix ||
                             type == DotOperationType::kMatrixVector;
  const bool rhs_is_matrix = type == DotOperationType::kMatrixMatrix ||
                             type == DotOperationType::kVectorMatrix;

  SmallVector<Value, 2> dyn_sizes;
  // Result rows come from the lhs matrix.
  if (lhs_is_matrix && lhs.getType().cast<ShapedType>().isDynamicDim(0)) {
    dyn_sizes.push_back(b.create<tensor::DimOp>(loc, lhs, 0));
  }
  // Result columns come from the rhs matrix.
  if (rhs_is_matrix && rhs.getType().cast<ShapedType>().isDynamicDim(1)) {
    dyn_sizes.push_back(b.create<tensor::DimOp>(loc, rhs, 1));
  }
  return dyn_sizes;
}
// Lowers `mhlo.dot` of kind `op_type` to the named Linalg op `LinalgOp`,
// accumulating into a zero-filled init tensor. Dots of any other kind are
// left for other instantiations of this pattern.
template <DotOperationType op_type, typename LinalgOp>
class DotOpConversion : public OpConversionPattern<mhlo::DotOp> {
 public:
  using OpConversionPattern<mhlo::DotOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DotOp op, mhlo::DotOp::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics(op)) return failure();
    if (GetDotOperationType(op) != op_type) return failure();

    Location loc = op.getLoc();
    auto result_type = op.getType().cast<ShapedType>();

    // Scalar zero of the result element type, used to fill the accumulator.
    Value zero = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(result_type.getElementType()));

    // Build the accumulator: an init tensor of the result shape filled with
    // zero.
    SmallVector<Value, 2> dyn_sizes = GetDotOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), op_type);
    auto init_tensor = GetInitTensor(rewriter, loc, result_type, dyn_sizes);
    Value accumulator =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

    rewriter.replaceOpWithNewOp<LinalgOp>(
        op, TypeRange{op.getType()}, ValueRange{adaptor.lhs(), adaptor.rhs()},
        ValueRange{accumulator}, PruneAttributeList(op));
    return success();
  }
};
// Collects `tensor.dim` values for the dynamic dimensions among the first
// three dimensions of `result_type`: result dims 0 and 1 are read from `lhs`
// dims 0 and 1, result dim 2 from `rhs` dim 2.
SmallVector<Value, 8> GetDotGeneralOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value lhs, Value rhs, ShapedType result_type) {
  SmallVector<Value, 8> dyn_sizes;
  // Appends the size of `source` along `dim` if the result is dynamic there.
  auto append_if_dynamic = [&](Value source, unsigned dim) {
    if (!result_type.isDynamicDim(dim)) return;
    dyn_sizes.push_back(b.create<tensor::DimOp>(loc, source, dim));
  };
  append_if_dynamic(lhs, 0);
  append_if_dynamic(lhs, 1);
  append_if_dynamic(rhs, 2);
  return dyn_sizes;
}
// Lowers `mhlo.dot_general` to `linalg.batch_matmul` when it has exactly the
// batch-matmul form: batching dimension {0} on both operands, lhs contracting
// dimension {2}, rhs contracting dimension {1}, and a rank-3 result. The
// accumulator is a zero-filled init tensor of the result type.
class DotGeneralBatchMatMulOpConversion
    : public OpConversionPattern<mhlo::DotGeneralOp> {
 public:
  using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DotGeneralOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics(op)) {
      return failure();
    }

    mhlo::DotDimensionNumbersAttr dim_numbers = op.dot_dimension_numbers();
    auto lhs_batching_dims = dim_numbers.getLhsBatchingDimensions();
    auto rhs_batching_dims = dim_numbers.getRhsBatchingDimensions();
    auto lhs_contracting_dims = dim_numbers.getLhsContractingDimensions();
    auto rhs_contracting_dims = dim_numbers.getRhsContractingDimensions();
    if (lhs_batching_dims.size() != 1 || lhs_batching_dims[0] != 0) {
      return rewriter.notifyMatchFailure(
          op, "expected lhs batching dimensions exactly {0}");
    }
    if (rhs_batching_dims.size() != 1 || rhs_batching_dims[0] != 0) {
      return rewriter.notifyMatchFailure(
          op, "expected rhs batching dimensions exactly {0}");
    }
    if (lhs_contracting_dims.size() != 1 || lhs_contracting_dims[0] != 2) {
      return rewriter.notifyMatchFailure(
          op, "expected lhs contracting dimensions exactly {2}");
    }
    if (rhs_contracting_dims.size() != 1 || rhs_contracting_dims[0] != 1) {
      return rewriter.notifyMatchFailure(
          op, "expected rhs contracting dimensions exactly {1}");
    }

    // The dimension-number checks above do not bound the operand ranks, so the
    // result may have a rank other than 3. The dynamic-size helper below
    // queries result dimensions 0..2, and `linalg.batch_matmul` only works on
    // rank-3 tensors, so reject everything else before creating any IR.
    auto output_type = op.getType().cast<ShapedType>();
    if (!output_type.hasRank() || output_type.getRank() != 3) {
      return rewriter.notifyMatchFailure(
          op, "expected a rank-3 result for batch matmul");
    }

    Location loc = op.getLoc();
    auto output_el_type = output_type.getElementType();
    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);
    auto zero_attr = rewriter.getZeroAttr(output_el_type);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
    // Zero-filled accumulator for the matmul.
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);
    Operation* linalg_op = rewriter.create<linalg::BatchMatmulOp>(
        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
        /*outputBuffers=*/ValueRange{zero_tensor}, PruneAttributeList(op));
    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};
// Returns true when the op that owns the region containing `op` belongs to
// the Linalg dialect.
bool IsInBodyOfLinalgOps(Operation* op) {
  Operation* enclosing_op = op->getParentRegion()->getParentOp();
  auto* linalg_dialect =
      enclosing_op->getContext()->getLoadedDialect<linalg::LinalgDialect>();
  return enclosing_op->getDialect() == linalg_dialect;
}
// Rewrites an MHLO op located in the body of a Linalg op into its scalar
// equivalent via `MhloOpToStdScalarOp::map`. Only matches when the op's
// result is a tensor type and at least one converted operand is not a tensor.
template <typename OpTy>
struct ReduceRegionXLAOpConversion : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      OpTy op, typename OpTy::Adaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!IsInBodyOfLinalgOps(op)) return failure();
    if (!op.getResult().getType().template isa<TensorType>()) return failure();

    // Skip ops whose operands are all still tensors.
    auto is_tensor = [](Value arg) {
      return arg.getType().template isa<TensorType>();
    };
    const bool all_operands_are_tensors =
        llvm::all_of(adaptor.getOperands(), is_tensor);
    if (all_operands_are_tensors) return failure();

    Value scalar_result = mhlo::MhloOpToStdScalarOp::map<OpTy>(
        op, getElementTypeOrSelf(op.getType()), adaptor.getOperands(),
        &rewriter);
    rewriter.replaceOp(op, scalar_result);
    return success();
  }
};
// Collects `tensor.dim` values of `arg` for every dynamic dimension of
// `result_type`. Result dimension `j` corresponds to the `j`-th dimension of
// `arg` that is not listed in `reduction_dims`.
SmallVector<Value, 8> GetReduceOpInitTensorDynSizes(
    OpBuilder& b, Location loc, Value arg, ShapedType result_type,
    ArrayRef<int64_t> reduction_dims) {
  // Keep dimensions as `int64_t` to avoid narrowing the attribute values.
  llvm::SmallSetVector<int64_t, 4> reduced;
  for (int64_t dim : reduction_dims) reduced.insert(dim);

  SmallVector<Value, 8> dyn_shape;
  const int64_t rank = arg.getType().cast<RankedTensorType>().getRank();
  // `i` walks the operand dimensions, `j` the surviving result dimensions.
  for (int64_t i = 0, j = 0; i < rank; ++i) {
    if (reduced.count(i)) continue;
    if (!result_type.isDynamicDim(j++)) continue;
    dyn_shape.push_back(b.create<tensor::DimOp>(loc, arg, i));
  }
  return dyn_shape;
}
// Rewrites an `mhlo.return` located in the body of a Linalg op into
// `linalg.yield`. Operands that are still shaped values are first turned into
// scalars with `tensor.extract`.
class ReduceRegionReturnOpConversion
    : public OpConversionPattern<mhlo::ReturnOp> {
 public:
  using OpConversionPattern<mhlo::ReturnOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReturnOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!IsInBodyOfLinalgOps(op)) return failure();

    SmallVector<Value, 4> yielded;
    for (Value operand : adaptor.getOperands()) {
      if (operand.getType().isa<ShapedType>()) {
        // Unwrap the shaped value at the operand's own location.
        operand =
            rewriter.create<tensor::ExtractOp>(operand.getLoc(), operand);
      }
      yielded.push_back(operand);
    }
    rewriter.replaceOpWithNewOp<linalg::YieldOp>(op, yielded);
    return success();
  }
};
/// Lowers `mhlo.reduce` to a `linalg.generic` op.
///
/// For every (input, init value, result) triple an output tensor is created
/// and filled with the scalar init value. The inputs are read through a
/// transposing indexing map so that the reduction loops are the innermost
/// ones, and the original reduction body is moved into the generic op with
/// its block arguments converted to element types.
class ReduceConversion : public OpConversionPattern<mhlo::ReduceOp> {
 public:
  using OpConversionPattern<mhlo::ReduceOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReduceOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = op.getLoc();

    int num_inputs = static_cast<int>(adaptor.inputs().size());
    // `getRank()` may only be queried on ranked types, so test `hasRank()`
    // first; unranked inputs now produce a match failure instead of tripping
    // an assertion. Rank-0 inputs keep being rejected exactly as before.
    if (llvm::any_of(adaptor.inputs(), [](Value v) {
          auto type = v.getType().cast<ShapedType>();
          return !type.hasRank() || !type.getRank();
        })) {
      return rewriter.notifyMatchFailure(op, "expects known-rank args");
    }
    auto src_rank = adaptor.inputs()[0].getType().cast<ShapedType>().getRank();

    SmallVector<int64_t, 4> reduction_dims = Extract1DVector(op.dimensions());

    SmallVector<Value> inputs, outputs;
    SmallVector<AffineMap, 3> indexing_maps;
    for (auto values :
         llvm::zip(adaptor.inputs(), adaptor.init_values(), op.getResults())) {
      // Check if init_value is constant. If so, inline the value into the
      // region.
      Value input = std::get<0>(values);
      Value init_value = std::get<1>(values);
      Value result = std::get<2>(values);
      Attribute init_const_val = GetInitValueAsConst(init_value);
      if (init_const_val) {
        init_value = rewriter.create<arith::ConstantOp>(
            init_value.getDefiningOp()->getLoc(), init_const_val);
      } else {
        // Otherwise read the scalar out of the init tensor.
        init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
      }

      inputs.push_back(input);
      auto result_type = result.getType().cast<ShapedType>();
      SmallVector<Value, 8> dyn_shape = GetReduceOpInitTensorDynSizes(
          rewriter, loc, input, result_type, reduction_dims);
      auto init_tensor = GetInitTensor(rewriter, loc, result_type, dyn_shape);
      Value filled_tensor =
          rewriter.create<linalg::FillOp>(loc, init_value, init_tensor)
              .result();
      outputs.push_back(filled_tensor);
    }

    // Prepare indexing maps for linalg generic op. The elements are for src
    // and dst. Transpose `src` to make the reduction loops be the innermost,
    // because it's easier to fully utilize processors.
    indexing_maps.append(
        num_inputs, GetTransposeMapForReduction(rewriter.getContext(), src_rank,
                                                reduction_dims));

    // The indexing map of `dst` should drop the reduction loops. Since the
    // reduction loops now are all in the innermost, drops
    // `reduction_dims.size()` dimensions. We don't need an inverse
    // permutation here because they are the same.
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0, e = src_rank - reduction_dims.size(); i < e; ++i)
      exprs.push_back(rewriter.getAffineDimExpr(i));
    indexing_maps.append(num_inputs,
                         AffineMap::get(src_rank, /*symbolCount=*/0, exprs,
                                        rewriter.getContext()));

    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/op.getResultTypes(), inputs,
        /*outputBuffers=*/ValueRange{outputs}, indexing_maps,
        GetParallelAndReductionIterators(src_rank, reduction_dims.size()),
        /*bodyBuild=*/nullptr, PruneAttributeList(op));

    // Convert the signature of the body. The reduce op region takes tensor
    // typed block arguments; after moving the region into the generic op,
    // block argument i is remapped to the element type of operand i of the
    // original op (inputs first, then init values).
    Region& region = linalg_op.region();
    rewriter.inlineRegionBefore(op.body(), region, region.end());
    TypeConverter::SignatureConversion signature_converter(num_inputs * 2);

    // Map the types of the inputs and the init values.
    for (const auto& it : llvm::enumerate(op.getOperation()->getOperands())) {
      signature_converter.addInputs(
          it.index(), it.value().getType().cast<ShapedType>().getElementType());
    }

    rewriter.applySignatureConversion(&region, signature_converter);
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};
/// Lowers an `mhlo.pad` that has no interior padding to the pad op produced
/// by `tensor::createPadScalarOp`, using the edge padding attributes as the
/// low/high amounts and the extracted scalar padding value as the fill.
struct PadOpConversion : public OpConversionPattern<mhlo::PadOp> {
  using OpConversionPattern<mhlo::PadOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::PadOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    // Any non-zero interior padding makes the op unsupported by this pattern.
    const bool has_interior_padding = llvm::any_of(
        op.interior_padding().getValues<APInt>(),
        [](const APInt& amount) { return amount.getZExtValue() != 0; });
    if (has_interior_padding) {
      return rewriter.notifyMatchFailure(op, "expected no interior padding");
    }

    auto loc = op.getLoc();
    // The padding value is a tensor operand; pull out its single element.
    Value scalar_pad =
        rewriter.createOrFold<tensor::ExtractOp>(loc, adaptor.padding_value());

    SmallVector<OpFoldResult, 4> low_amounts(
        op.edge_padding_low().getValues<IntegerAttr>());
    SmallVector<OpFoldResult, 4> high_amounts(
        op.edge_padding_high().getValues<IntegerAttr>());

    Type padded_type = op.getResult().getType();
    auto new_pad = tensor::createPadScalarOp(
        padded_type, adaptor.operand(), scalar_pad, low_amounts, high_amounts,
        /*nofold=*/false, loc, rewriter);
    rewriter.replaceOp(op, new_pad.getResult());
    return success();
  }
};
/// Pads `input` with the constant described by `padAttr`.
///
/// `pad` stores one (low, high) pair per dimension of `input`, flattened as
/// [low0, high0, low1, high1, ...]. The result is a ranked tensor whose
/// dimension sizes are the input sizes plus the respective low and high
/// amounts.
static Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                      Attribute padAttr, OpBuilder& rewriter) {
  auto shapedInput = input.getType().cast<ShapedType>();
  Type elementTy = shapedInput.getElementType();
  ArrayRef<int64_t> dims = shapedInput.getShape();
  assert((dims.size() * 2) == pad.size() &&
         "There should be 2 padding values per dimension, i.e low and high.");

  SmallVector<int64_t> resultDims;
  SmallVector<OpFoldResult> lows;
  SmallVector<OpFoldResult> highs;
  for (int dim = 0, numDims = dims.size(); dim < numDims; ++dim) {
    int64_t lo = pad[dim * 2];
    int64_t hi = pad[dim * 2 + 1];
    resultDims.push_back(dims[dim] + hi + lo);
    lows.push_back(rewriter.getIndexAttr(lo));
    highs.push_back(rewriter.getIndexAttr(hi));
  }

  // Materialize the fill value once and hand everything to the pad builder.
  Value padValue = rewriter.create<arith::ConstantOp>(loc, padAttr);
  auto paddedTy = RankedTensorType::get(resultDims, elementTy);
  return tensor::createPadScalarOp(paddedTy, input, padValue, lows, highs,
                                   /*nofold=*/false, loc, rewriter);
}
/// Lowers a non-grouped `mhlo.conv` with canonical dimension numbers to a
/// Linalg named op selected by the result rank: `linalg::MatmulOp` for rank 2
/// and `Conv1DNwcWcfOp` / `Conv2DNhwcHwcfOp` / `Conv3DNdhwcDhwcfOp` for ranks
/// 3, 4 and 5. Ops whose `feature_group_count` is not 1 are rejected, so
/// grouped and depthwise convolutions are not handled by this pattern.
struct NormalConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
    if (op.feature_group_count() != 1u) return failure();

    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto result_type = op.getResult().getType().cast<ShapedType>();
    int64_t rank = result_type.getRank();

    // The result is laid out as [N, spatial dims..., F]. Only N and F may be
    // dynamic; N is read from the input, F from the filter.
    SmallVector<Value, 8> dynamic_extents;
    if (result_type.isDynamicDim(0)) {
      dynamic_extents.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
    }
    for (int64_t dim = 1, last = rank - 1; dim < last; ++dim) {
      if (!result_type.isDynamicDim(dim)) continue;
      return rewriter.notifyMatchFailure(
          op, "expected output spatial dims to be static shapes");
    }
    if (result_type.isDynamicDim(rank - 1)) {
      dynamic_extents.push_back(
          rewriter.create<tensor::DimOp>(loc, filter, rank - 1));
    }

    // Zero-filled accumulator the named op writes into.
    Value init_tensor = rewriter.create<linalg::InitTensorOp>(
        loc, dynamic_extents, result_type.getShape(),
        result_type.getElementType());
    auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
    Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
    Value accumulator =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

    Attribute strides = op.window_stridesAttr();
    // TODO(ataei): Only support dilated kernel right now. We need to consider
    // input dilation for deconvolution cases.
    Attribute dilations = op.rhs_dilationAttr();

    // Explicitly pad the input when the op carries non-zero padding. The
    // batch and feature dimensions get a zero (low, high) pair each.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (padding && !isSplatValue(padding, 0)) {
      SmallVector<int64_t> pad(2, 0);
      pad.append(Extract1DVector(padding));
      pad.append(2, 0);
      input = applyPad(loc, input, pad, zero_attr, rewriter);
    }

    linalg::LinalgOp conv_op;
    switch (rank) {
      case 2:
        conv_op = rewriter.create<linalg::MatmulOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{accumulator}, PruneAttributeList(op));
        break;
      case 3:
        conv_op = rewriter.create<linalg::Conv1DNwcWcfOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{accumulator}, strides, dilations,
            PruneAttributeList(op));
        break;
      case 4:
        conv_op = rewriter.create<linalg::Conv2DNhwcHwcfOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{accumulator}, strides, dilations,
            PruneAttributeList(op));
        break;
      case 5:
        conv_op = rewriter.create<linalg::Conv3DNdhwcDhwcfOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{accumulator}, strides, dilations,
            PruneAttributeList(op));
        break;
      default:
        return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
    }

    rewriter.replaceOp(op, conv_op.getOperation()->getResults());
    return success();
  }
};
/// Lowers a 2-D depthwise `mhlo.conv` to `linalg::DepthwiseConv2DNhwcHwcmOp`
/// (channel multiplier != 1) or `linalg::DepthwiseConv2DNhwcHwcOp` (channel
/// multiplier == 1).
///
/// The pattern requires `batch_group_count == 1`, `feature_group_count != 1`,
/// unit (or absent) lhs dilation and a statically shaped result. When
/// dimension numbers are present it additionally requires two spatial
/// dimensions, `feature_group_count` equal to the input feature count, and a
/// canonical layout.
struct DepthwiseConvOpConversion : public OpConversionPattern<mhlo::ConvOp> {
  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    if (op.batch_group_count() != 1) return failure();
    // Fall into the normal convolution cases.
    if (op.feature_group_count() == 1) return failure();

    if ((op.lhs_dilation() && !isSplatValue(*op.lhs_dilation(), 1))) {
      return rewriter.notifyMatchFailure(
          op, "non-one lhs- dialation unsupported yet");
    }

    if (const mhlo::ConvDimensionNumbersAttr& dimension_numbers =
            op.dimension_numbers()) {
      // Make sure that this is 2-D convolution.
      const auto spatial_rank =
          llvm::size(dimension_numbers.getInputSpatialDimensions());
      if (spatial_rank != 2) {
        return rewriter.notifyMatchFailure(op,
                                           "only support 2-D cases for now");
      }

      // Make sure that this is depthwise convolution.
      int64_t input_feature_dim = dimension_numbers.getInputFeatureDimension();
      int64_t input_feature_count =
          op.lhs().getType().cast<ShapedType>().getDimSize(input_feature_dim);
      if (op.feature_group_count() != input_feature_count) {
        return rewriter.notifyMatchFailure(op, "not depth-wise convolution");
      }

      // Make sure that this convolution has a canonical form.
      if (!HasCanonicalDimensionNumbers(dimension_numbers)) {
        return rewriter.notifyMatchFailure(op, "does not have canonical form");
      }
    }

    // Missing strides / dilations default to 1 in both spatial dimensions.
    DenseIntElementsAttr window_strides;
    if (op.window_strides()) {
      window_strides = op.window_strides().getValue();
    } else {
      window_strides = rewriter.getI64VectorAttr({1, 1});
    }

    DenseIntElementsAttr rhs_dilation;
    if (op.rhs_dilation()) {
      rhs_dilation = op.rhs_dilation().getValue();
    } else {
      rhs_dilation = rewriter.getI64VectorAttr({1, 1});
    }

    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto result_type = op.getResult().getType().cast<RankedTensorType>();
    if (!result_type.hasStaticShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "expected output has static shapes");
    }

    auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());

    // Check if padding is zero or not. If it is not zero, we should pad the
    // input.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (padding && !isSplatValue(padding, 0)) {
      // Add the zero padding for the batch dim.
      SmallVector<int64_t> pad(2, 0);
      // Add the padding values.
      pad.append(Extract1DVector(padding));
      // Add the zero padding for the feature dim.
      pad.append(2, 0);

      // Pad the given input using `zero_attr` according to the low and high
      // values in the `pad`.
      input = applyPad(loc, input, pad, zero_attr, rewriter);
    }

    auto filter_dims =
        llvm::to_vector<4>(op.rhs().getType().cast<ShapedType>().getShape());

    // Builds the reassociation group [start, end).
    auto get_indices_vector = [](int start, int end) {
      return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
    };

    if (filter_dims[2] * filter_dims[3] != op.feature_group_count()) {
      // For cases where channel multiplier != 1
      auto output_dims = result_type.getShape();
      auto channel_multiplier = filter_dims[3];
      // The named op produces a rank-5 result: the last result dimension is
      // split into (channels / multiplier, multiplier).
      SmallVector<int64_t> reshaped_output_dims;
      reshaped_output_dims.assign(output_dims.begin(), output_dims.end());
      reshaped_output_dims.push_back(channel_multiplier);
      reshaped_output_dims[3] /= channel_multiplier;

      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, reshaped_output_dims, result_type.getElementType());
      Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
      Value zero_tensor =
          rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

      auto reshaped_output_type = RankedTensorType::get(
          reshaped_output_dims, result_type.getElementType());
      auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
          op.getLoc(), reshaped_output_type, ValueRange{input, filter},
          ValueRange{zero_tensor}, window_strides, rhs_dilation,
          PruneAttributeList(op));

      // Collapse the last two dimensions of the rank-5 convolution result so
      // that the replacement has the rank-4 type of the original op.
      SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 3), get_indices_vector(3, 5)};
      rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
          op, result_type, conv.getResult(0), collapsed_dim_list);
    } else {
      // For cases where channel multiplier == 1
      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, result_type.getShape(), result_type.getElementType());
      Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
      Value zero_tensor =
          rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

      // Collapse the last two filter dimensions into one, turning the rank-4
      // filter into the rank-3 filter that DepthwiseConv2DNhwcHwcOp takes.
      filter_dims[2] = static_cast<int64_t>(op.feature_group_count());
      filter_dims.pop_back();

      // The collapsed filter must keep the element type of the filter value
      // being collapsed (not the result's element type): a collapse_shape
      // cannot change the element type.
      RankedTensorType filter_shape = RankedTensorType::get(
          filter_dims, filter.getType().cast<ShapedType>().getElementType());

      SmallVector<ReassociationIndices, 4> collapsed_dim_list = {
          get_indices_vector(0, 1), get_indices_vector(1, 2),
          get_indices_vector(2, 4)};

      Value reshaped_filter = rewriter.create<tensor::CollapseShapeOp>(
          loc, filter_shape, filter, collapsed_dim_list);

      rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(
          op, result_type, ValueRange{input, reshaped_filter},
          ValueRange{zero_tensor}, window_strides, rhs_dilation,
          PruneAttributeList(op));
    }

    return success();
  }
};
/// Lowers `mhlo.reduce_window` to a `linalg.generic` op.
///
/// Every window dimension whose size is not 1 becomes an extra reduction loop
/// of the generic op. The inputs are indexed with
/// `stride * d_i + dilation * w_k`, an additional init tensor carrying the
/// window shape is passed as the last input to fix the reduction loop bounds,
/// and the outputs start as the init values broadcast to the result shapes.
/// Non-unit base dilations and dynamically shaped results are rejected.
struct ReduceWindowOpOnTensorsGenericConversion
    : public OpConversionPattern<mhlo::ReduceWindowOp> {
  using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    MLIRContext* ctx = op->getContext();
    Location loc = op.getLoc();
    llvm::SmallVector<Value> init_values = adaptor.init_values();
    llvm::SmallVector<Type> result_types = llvm::to_vector(op.getResultTypes());
    auto num_operands = init_values.size();

    llvm::SmallVector<int64_t> window_dimensions =
        Extract1DVector(op.window_dimensions());

    llvm::SmallVector<int64_t> padding;
    if (op.padding()) {
      padding = Extract1DVector(*op.padding());
    }

    // Base dilations are only supported when they are all 1. The presence
    // check must be on `base_dilations` itself, since that is the optional
    // dereferenced below.
    llvm::SmallVector<int64_t> base_dilations;
    if (op.base_dilations()) {
      base_dilations = Extract1DVector(*op.base_dilations());
      if (llvm::any_of(base_dilations, [](int64_t& x) { return x != 1; }))
        return failure();
    }

    llvm::SmallVector<int64_t> window_strides(window_dimensions.size(), 1);
    if (op.window_strides()) {
      window_strides = Extract1DVector(*op.window_strides());
    }

    llvm::SmallVector<int64_t> window_dilations(window_dimensions.size(), 1);
    if (op.window_dilations()) {
      window_dilations = Extract1DVector(*op.window_dilations());
    }

    auto rank = window_dimensions.size();
    SmallVector<AffineExpr, 2> src_exprs;
    SmallVector<AffineExpr, 2> window_exprs;
    SmallVector<AffineExpr, 2> dst_exprs;
    SmallVector<int64_t> filtered_window_dims;

    // Loop dimensions [0, rank) iterate over the result; each non-unit window
    // dimension gets one additional loop dimension starting at `rank`.
    int window_dim = 0;
    for (int i = 0; i < rank; i++) {
      AffineExpr src_expr = mlir::getAffineDimExpr(i, ctx);

      if (window_strides[i] != 1) src_expr = src_expr * window_strides[i];

      if (window_dimensions[i] != 1) {
        filtered_window_dims.push_back(window_dimensions[i]);
        AffineExpr window_expr = mlir::getAffineDimExpr(rank + window_dim, ctx);
        window_exprs.push_back(window_expr);

        if (window_dilations[i] != 1)
          window_expr = window_expr * window_dilations[i];

        src_expr = src_expr + window_expr;
        window_dim++;
      }

      src_exprs.push_back(src_expr);
      dst_exprs.push_back(mlir::getAffineDimExpr(i, ctx));
    }

    SmallVector<AffineMap, 4> inferred_maps =
        AffineMap::inferFromExprList({src_exprs, window_exprs, dst_exprs});

    // One source map per input, one map for the window tensor, one
    // destination map per output.
    SmallVector<AffineMap, 4> indexing_maps;
    indexing_maps.append(num_operands, inferred_maps[0]);
    indexing_maps.append(1, inferred_maps[1]);
    indexing_maps.append(num_operands, inferred_maps[2]);

    // Setup the initial values.
    llvm::SmallVector<Value> broadcast_values;
    for (uint64_t i = 0, s = init_values.size(); i < s; i++) {
      Value init_value = init_values[i];
      auto result_ty = result_types[i].cast<ShapedType>();
      if (!result_ty.hasStaticShape()) return failure();

      auto broadcast_sizes = rewriter.getI64TensorAttr(result_ty.getShape());
      broadcast_values.push_back(rewriter.create<mhlo::BroadcastOp>(
          loc, result_ty, init_value, broadcast_sizes));
    }

    llvm::SmallVector<Value> inputs = llvm::to_vector(adaptor.inputs());

    // Pad as necessary. The predicate takes the full 64-bit value so that
    // large paddings are not truncated to zero.
    if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) {
      // `padding` is flattened as [low0, high0, low1, high1, ...].
      llvm::SmallVector<int64_t> static_lows;
      llvm::SmallVector<int64_t> static_highs;
      for (size_t i = 0, e = padding.size(); i < e; i += 2) {
        static_lows.push_back(padding[i]);
        static_highs.push_back(padding[i + 1]);
      }
      for (auto values : llvm::zip(inputs, init_values)) {
        auto& input = std::get<0>(values);
        auto& init_value = std::get<1>(values);

        // Extract the single element from init value. This mimic the lowering
        // behavior of mhlo.pad.
        Value padding_value =
            rewriter.createOrFold<tensor::ExtractOp>(loc, init_value);

        auto pad_op = rewriter.create<tensor::PadOp>(
            loc, input, static_lows, static_highs, ValueRange{}, ValueRange{});

        // The pad body takes one index argument per dimension and yields the
        // padding value.
        SmallVector<Type, 4> block_arg_types;
        block_arg_types.assign(input.getType().cast<ShapedType>().getRank(),
                               rewriter.getIndexType());
        auto& region = pad_op.region();

        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.createBlock(
            &region, region.end(), block_arg_types,
            SmallVector<Location>(block_arg_types.size(), loc));
        rewriter.create<tensor::YieldOp>(loc, padding_value);

        input = pad_op.getResult();
      }
    }

    // Add the extra input for the reduction dimension.
    inputs.push_back(rewriter.create<linalg::InitTensorOp>(
        loc, filtered_window_dims, rewriter.getF32Type()));

    rewriter.setInsertionPoint(op);
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/result_types,
        /*inputs=*/inputs,
        /*outputs=*/broadcast_values, indexing_maps,
        GetParallelAndReductionIterators(rank + filtered_window_dims.size(),
                                         filtered_window_dims.size()),
        /*bodyBuild=*/nullptr, PruneAttributeList(op));

    // Convert the signature of the body. This includes converting scalar
    // tensors to their scalar values and inserting an additional block arg for
    // the window arg.
    Region& region = linalg_op.region();
    rewriter.cloneRegionBefore(op.body(), region, region.end());

    TypeConverter::SignatureConversion signature_converter(
        inputs.size() + op->getNumResults() - 1);

    // Original block arguments for the real inputs (all but the trailing
    // window tensor).
    for (uint64_t i = 0, s = inputs.size(); i < s - 1; i++) {
      signature_converter.addInputs(
          i, inputs[i].getType().cast<ShapedType>().getElementType());
    }

    // New block argument for the window tensor; it has no counterpart in the
    // original body.
    signature_converter.addInputs(
        inputs.back().getType().cast<ShapedType>().getElementType());

    // Original block arguments for the accumulators.
    for (uint64_t i = 0, s = result_types.size(); i < s; i++) {
      auto idx = inputs.size() + i - 1;
      signature_converter.addInputs(
          idx, result_types[i].cast<ShapedType>().getElementType());
    }

    rewriter.applySignatureConversion(&region, signature_converter);
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};
/// Lowers rank-4 / rank-5 `mhlo.reduce_window` ops with f32 inputs to Linalg
/// pooling named ops.
///
/// The pattern requires all-zero padding and window dimensions / strides of
/// the form [1, x, y, (z), 1]. Each result is lowered separately to the
/// pooling op that matches its reduction body (min, max or add).
struct ReduceWindowOpConversion
    : public OpConversionPattern<mhlo::ReduceWindowOp> {
  using OpConversionPattern<mhlo::ReduceWindowOp>::OpConversionPattern;

  /// mhlo.reduce_window is mapped to a linalg.pooling operation. The type of
  /// the pooling is determined based on the body of the reduce window
  /// operation. This class enumerates the different variants.
  enum class PoolingType {
    kInvalid,
    k2DMin,
    k3DMin,
    k2DMax,
    k3DMax,
    k2DAdd,
    k3DAdd,
  };

  /// Classifies result `result_index` of `reduce_op` by its reduction op
  /// (min / max / add) and result rank (4 -> 2-D pooling, 5 -> 3-D pooling).
  /// Returns `kInvalid` for any other combination.
  static PoolingType getPoolingType(mhlo::ReduceWindowOp reduce_op,
                                    int result_index) {
    auto rank =
        reduce_op.getResultTypes()[result_index].cast<ShapedType>().getRank();
    if (Operation* op = reduce_op.getReductionOp(result_index)) {
      if (isa<mhlo::MinOp>(*op) && rank == 4) return PoolingType::k2DMin;
      if (isa<mhlo::MinOp>(*op) && rank == 5) return PoolingType::k3DMin;
      if (isa<mhlo::MaxOp>(*op) && rank == 4) return PoolingType::k2DMax;
      if (isa<mhlo::MaxOp>(*op) && rank == 5) return PoolingType::k3DMax;
      if (isa<mhlo::AddOp>(*op) && rank == 4) return PoolingType::k2DAdd;
      if (isa<mhlo::AddOp>(*op) && rank == 5) return PoolingType::k3DAdd;
    }
    return PoolingType::kInvalid;
  }

  LogicalResult matchAndRewrite(
      mhlo::ReduceWindowOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const override {
    auto loc = op.getLoc();
    int rank = op.getResultTypes()[0].cast<ShapedType>().getRank();
    if (rank != 4 && rank != 5) {
      return rewriter.notifyMatchFailure(
          op, "expected NHWC/NDHWC pooling-based op");
    }

    if (op.padding() && !isSplatValue(*op.padding(), 0)) {
      return rewriter.notifyMatchFailure(op, "require paddings are all zero");
    }

    // Window sizes of the spatial dimensions only (first and last dimension
    // are excluded).
    int last_dim = rank - 1;
    SmallVector<int64_t, 2> fake_window_shapes;
    for (int i = 1; i < last_dim; ++i) {
      fake_window_shapes.push_back(
          op.window_dimensions().getValues<int64_t>()[i]);
    }

    if (op.window_strides() &&
        (op.window_strides().getValue().getValues<int64_t>()[0] != 1 ||
         op.window_strides().getValue().getValues<int64_t>()[last_dim] != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_strides to be [1,x,y,(z),1]");
    }
    if (op.window_dimensions() &&
        (op.window_dimensions().getValues<int64_t>()[0] != 1 ||
         op.window_dimensions().getValues<int64_t>()[last_dim] != 1)) {
      return rewriter.notifyMatchFailure(
          op, "expected window_dimensions to be [1,x,y,(z),1]");
    }

    // Spatial strides; default to all ones when the attribute is absent.
    Attribute strides;
    SmallVector<int64_t> vec;
    if (op.window_stridesAttr()) {
      for (int i = 1; i < last_dim; ++i) {
        vec.push_back(op.window_strides().getValue().getValues<int64_t>()[i]);
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    strides = rewriter.getI64VectorAttr(vec);

    // Spatial dilations; default to all ones when the attribute is absent.
    Attribute dilations;
    vec.clear();
    if (op.window_dilations()) {
      for (int i = 1; i < last_dim; ++i) {
        vec.push_back(op.window_dilations().getValue().getValues<int64_t>()[i]);
      }
    } else {
      vec.assign(rank - 2, 1);
    }
    dilations = rewriter.getI64VectorAttr(vec);

    SmallVector<Value> pooling_ops;

    ValueRange inputs = adaptor.inputs();
    ValueRange init_values = adaptor.init_values();
    for (auto it : llvm::zip(op.getResults(), inputs, init_values)) {
      OpResult result = std::get<0>(it);
      Value input = std::get<1>(it);
      Value init_value = std::get<2>(it);
      auto result_type = result.getType().cast<ShapedType>();
      if (!input.getType().cast<ShapedType>().getElementType().isF32()) {
        return rewriter.notifyMatchFailure(op,
                                           "expected element type to be f32");
      }

      // Create a fake window dimension.
      auto fake_window_dims = rewriter.create<linalg::InitTensorOp>(
          loc, fake_window_shapes, result_type.getElementType());

      // Compute the runtime size of every dynamic result dimension.
      SmallVector<Value> result_dynamic_dims;
      for (auto& en : llvm::enumerate(result_type.getShape())) {
        if (en.value() != ShapedType::kDynamicSize) continue;
        Value dim_size = rewriter.create<tensor::DimOp>(loc, input, en.index());
        if (en.index() == 0 || en.index() == rank - 1) {
          // batch dims and channel dims can be derived from input dims
          // directly.
          result_dynamic_dims.push_back(dim_size);
        } else {
          auto i = en.index() - 1;
          auto stride =
              strides.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
          auto dilation =
              dilations.cast<DenseIntElementsAttr>().getValues<int64_t>()[i];
          // A window of `w` elements with dilation `d` touches input offsets
          // 0, d, ..., (w - 1) * d, i.e. it spans (w - 1) * d + 1 input
          // elements. Hence
          //   output_size = (input_size - ((w - 1) * d + 1)) / stride + 1.
          // (For d == 1 the span is simply w.)
          Value offset = rewriter.create<arith::ConstantIndexOp>(
              loc, (fake_window_shapes[i] - 1) * dilation + 1);
          dim_size = rewriter.create<arith::SubIOp>(loc, dim_size, offset);
          dim_size = rewriter.create<arith::DivUIOp>(
              loc, dim_size,
              rewriter.create<arith::ConstantIndexOp>(loc, stride));
          dim_size = rewriter.create<arith::AddIOp>(
              loc, dim_size, rewriter.create<arith::ConstantIndexOp>(loc, 1));
          result_dynamic_dims.push_back(dim_size);
        }
      }
      Value init_tensor = rewriter.create<linalg::InitTensorOp>(
          loc, result_dynamic_dims, result_type.getShape(),
          result_type.getElementType());

      // The output starts filled with the scalar init value.
      init_value = rewriter.create<tensor::ExtractOp>(loc, init_value);
      Value filled_init_tensor =
          rewriter.create<linalg::FillOp>(loc, init_value, init_tensor)
              .getResult(0);

      // Creates the pooling op whose type is carried by the (null) pointer
      // argument; all pooling variants share the same operand list.
      auto create_op = [&](auto* type_ptr) -> linalg::LinalgOp {
        return cast<linalg::LinalgOp>(
            rewriter
                .create<std::remove_pointer_t<decltype(type_ptr)>>(
                    loc, ArrayRef<Type>{result_type},
                    ValueRange{input, fake_window_dims.getResult()},
                    filled_init_tensor, strides, dilations,
                    PruneAttributeList(op))
                .getOperation());
      };
      linalg::LinalgOp pooling_op;
      PoolingType pooling_type = getPoolingType(op, result.getResultNumber());
      switch (pooling_type) {
        case PoolingType::k2DMin: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMin: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcMinOp*>(nullptr));
          break;
        }
        case PoolingType::k2DMax: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k3DMax: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcMaxOp*>(nullptr));
          break;
        }
        case PoolingType::k2DAdd: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::k3DAdd: {
          pooling_op =
              create_op(static_cast<linalg::PoolingNdhwcSumOp*>(nullptr));
          break;
        }
        case PoolingType::kInvalid:
          return rewriter.notifyMatchFailure(op, "unknown reduction operation");
      }
      pooling_ops.push_back(pooling_op->getResult(0));
    }
    rewriter.replaceOp(op, pooling_ops);
    return success();
  }
};
/// Converts mhlo.torch_index_select op to a linalg.generic op.
///
/// The generic op iterates over the whole result with parallel loops. Its
/// inputs are the index tensor and a dummy init tensor that only carries the
/// shape of the non-indexed ("slice") dimensions; the body casts the current
/// index element to `index` and reads the selected element of the input
/// tensor with `tensor.extract`.
struct TorchIndexSelectOpConversion
    : public OpConversionPattern<mhlo::TorchIndexSelectOp> {
  using OpConversionPattern<mhlo::TorchIndexSelectOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::TorchIndexSelectOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    int axis = static_cast<int>(op.dim());
    int batch = static_cast<int>(op.batch_dims());
    auto index_shaped_type = adaptor.index().getType().cast<ShapedType>();
    int num_indices = static_cast<int>(index_shaped_type.getRank());
    auto input_shaped_type = adaptor.input().getType().cast<ShapedType>();
    // Negative `dim` / `batch_dims` count from the end.
    if (axis < 0) axis += static_cast<int>(input_shaped_type.getRank());
    if (batch < 0) batch += num_indices;

    Location loc = op.getLoc();
    auto result_type =
        this->typeConverter->convertType(op.getResult().getType())
            .cast<ShapedType>();
    int rank = static_cast<int>(result_type.getRank());

    // The output shape is
    //   `params[:axis] + indices[batch_dims:] + params[axis + 1:]`
    SmallVector<Value, 4> dyn_sizes;
    for (int i = 0; i < rank; ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      if (i < axis) {
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.input(), i));
      } else if (i < (axis + num_indices - batch)) {
        int idx = i - axis + batch;
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.index(), idx));
      } else {
        int idx = i - (axis + num_indices - batch) + axis + 1;
        dyn_sizes.push_back(
            rewriter.create<tensor::DimOp>(loc, adaptor.input(), idx));
      }
    }

    // Generate dummy tensor to preserve slice shape information.
    SmallVector<int64_t> slice_shape;
    SmallVector<Value, 4> dyn_slice_sizes;
    SmallVector<AffineExpr, 4> slice_exprs;
    auto result_shape = result_type.getShape();
    for (int i = 0; i < axis; ++i) {
      slice_exprs.push_back(rewriter.getAffineDimExpr(i));
      slice_shape.push_back(result_shape[i]);
      if (!result_type.isDynamicDim(i)) continue;
      dyn_slice_sizes.push_back(
          rewriter.create<tensor::DimOp>(loc, adaptor.input(), i));
    }
    for (int i = axis + num_indices - batch; i < rank; ++i) {
      slice_exprs.push_back(rewriter.getAffineDimExpr(i));
      slice_shape.push_back(result_shape[i]);
      if (!result_type.isDynamicDim(i)) continue;
      int idx = i - (axis + num_indices - batch) + axis + 1;
      dyn_slice_sizes.push_back(
          rewriter.create<tensor::DimOp>(loc, adaptor.input(), idx));
    }

    // Setup AffineMap for input tensor.
    SmallVector<AffineExpr, 4> exprs;
    for (int i = 0; i < batch; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(i));
    }
    for (int i = 0, e = num_indices - batch; i < e; ++i) {
      exprs.push_back(rewriter.getAffineDimExpr(axis + i));
    }

    // Maps, in order: index tensor, dummy slice tensor, result (identity).
    SmallVector<AffineMap, 2> indexing_maps;
    indexing_maps.emplace_back(
        AffineMap::get(rank, /*symbolCount=*/0, exprs, rewriter.getContext()));
    indexing_maps.emplace_back(AffineMap::get(
        rank, /*symbolCount=*/0, slice_exprs, rewriter.getContext()));
    indexing_maps.emplace_back(rewriter.getMultiDimIdentityMap(rank));

    Value slice_op = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_slice_sizes, slice_shape, result_type.getElementType());

    Value init_op = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensors=*/ArrayRef<Type>{result_type},
        /*inputs=*/ValueRange{adaptor.index(), slice_op},
        /*outputs=*/init_op, indexing_maps, GetNParallelLoopsAttrs(rank),
        /*bodyBuild=*/nullptr, PruneAttributeList(op));

    SmallVector<Type, 4> body_arg_types;
    SmallVector<Value, 2> linalg_op_args = {adaptor.index(), slice_op};

    // `createBlock` moves the insertion point into the new block, so the
    // guard has to be taken before it in order to restore the original
    // insertion point when this function returns.
    OpBuilder::InsertionGuard guard(rewriter);

    // Add a block to the region.
    auto* region = &linalg_op.region();
    auto* block = rewriter.createBlock(region, region->end());
    for (auto block_args : linalg_op_args) {
      body_arg_types.push_back(
          block_args.getType().cast<ShapedType>().getElementType());
    }
    block->addArguments(body_arg_types,
                        SmallVector<Location>(body_arg_types.size(), loc));
    block->addArguments(result_type.getElementType(), loc);
    rewriter.setInsertionPointToEnd(block);

    // Block argument 0 is the current element of the index tensor.
    Value casted_value = rewriter.create<arith::IndexCastOp>(
        loc, rewriter.getIndexType(), block->getArgument(0));

    // Input coordinates: leading loop indices, the gathered index along
    // `axis`, then the trailing loop indices.
    SmallVector<Value, 4> indices;
    for (int i = 0; i < axis; ++i) {
      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    indices.push_back(casted_value);
    for (int i = axis + num_indices - batch; i < rank; ++i) {
      indices.push_back(rewriter.create<linalg::IndexOp>(loc, i));
    }
    Value res =
        rewriter.create<tensor::ExtractOp>(loc, adaptor.input(), indices);
    rewriter.create<linalg::YieldOp>(loc, res);

    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};
/// Lowers mhlo.gather to a linalg.generic that loops over every element of the
/// result and computes the corresponding operand index.
///
/// This lowering encompasses the full range of the Gather operation and
/// therefore is very general: it just loops over the output and calculates the
/// corresponding input index. It follows the explanation at
/// https://www.tensorflow.org/xla/operation_semantics#gather. The compiler
/// should be able to optimize that a bit, but in order to get efficient
/// lowerings, special-cases of gather should be extracted in separate
/// lowerings, and ideally encapsulated as separate ops or canonicalization
/// patterns.
///
/// The pattern fails to match when the result or start_indices is unranked, or
/// when a dynamically shaped result cannot have its shape reified.
///
/// NOTE(review): no clamping of the start indices is emitted here; the index
/// read from start_indices is used as-is. Confirm whether out-of-bounds start
/// indices need handling by callers.
struct GatherConversion : public OpConversionPattern<mhlo::GatherOp> {
  using OpConversionPattern<mhlo::GatherOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::GatherOp gatherOp, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    Location loc = gatherOp.getLoc();

    Value startIndices = adaptor.start_indices();
    Value operand = adaptor.operand();

    RankedTensorType resultType =
        gatherOp.getResult().getType().dyn_cast<RankedTensorType>();
    RankedTensorType startIndicesType =
        startIndices.getType().dyn_cast<RankedTensorType>();

    // We could actually deal with an unranked result by inferring the result
    // rank, but the current reifyReturnTypes doesn't support unranked either.
    if (!resultType || !startIndicesType)
      return rewriter.notifyMatchFailure(gatherOp,
                                         "unranked start indices or result");

    int resultRank = resultType.getRank();
    // slice_sizes has to have the same size as operand.rank, and doing it this
    // way permits an unranked operand.
    int operandRank = gatherOp.slice_sizes().getNumElements();

    int64_t indexVectorDim = gatherOp.dimension_numbers().getIndexVectorDim();

    ArrayRef<int64_t> offsetDims = gatherOp.dimension_numbers().getOffsetDims();
    ArrayRef<int64_t> collapsedSliceDims =
        gatherOp.dimension_numbers().getCollapsedSliceDims();
    ArrayRef<int64_t> startIndexMap =
        gatherOp.dimension_numbers().getStartIndexMap();

    // Reads input[index] and casts the loaded value to `index` type.
    auto extractAsIndex = [&](Value input, ArrayRef<Value> index) -> Value {
      return rewriter.create<arith::IndexCastOp>(
          loc, rewriter.getIndexType(),
          rewriter.create<tensor::ExtractOp>(loc, input, index));
    };

    // We'll need these later and creating them on demand we end up with
    // duplicates, which also makes lit tests really hard to write.
    // constants[0] is used unconditionally below as the "zero" filler, so at
    // least one constant is always materialized, even for rank-0 gathers.
    const int numConstants = std::max({resultRank, operandRank, 1});
    SmallVector<Value> constants;
    constants.reserve(numConstants);
    for (int i = 0; i < numConstants; ++i)
      constants.push_back(
          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(i)));

    // Create ops to calculate the dynamic dimensions of the return shape, which
    // are needed for the init tensor.
    SmallVector<Value> dynDimSizes;
    if (!resultType.hasStaticShape()) {
      SmallVector<Value> returnShapes;
      if (failed(gatherOp.reifyReturnTypeShapes(rewriter, adaptor.getOperands(),
                                                returnShapes)))
        return rewriter.notifyMatchFailure(gatherOp,
                                           "could not reify return shape");
      assert(returnShapes.size() == 1);
      Value returnShape = returnShapes[0];

      for (int i = 0; i < resultRank; ++i)
        if (resultType.isDynamicDim(i))
          dynDimSizes.push_back(extractAsIndex(returnShape, constants[i]));
    }

    Value initOp = rewriter.create<linalg::InitTensorOp>(
        loc, dynDimSizes, resultType.getShape(), resultType.getElementType());

    // The generic op has no tensor inputs: the operand and start_indices are
    // read with explicit tensor.extract ops inside the body.
    ValueRange ins;
    SmallVector<AffineMap, 1> indexingMaps(
        {rewriter.getMultiDimIdentityMap(resultRank)});
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/resultType,
        /*inputs=*/ins,
        /*outputs=*/initOp, indexingMaps, GetNParallelLoopsAttrs(resultRank),
        /*bodyBuild=*/nullptr, PruneAttributeList(gatherOp));

    // Now populate the linalg generic region. The guard has to be taken before
    // createBlock, because createBlock already moves the insertion point into
    // the new block; taking it afterwards would "restore" into the region.
    OpBuilder::InsertionGuard guard(rewriter);
    auto* region = &linalgOp.region();
    auto* block = rewriter.createBlock(region, region->end());
    block->addArguments(resultType.getElementType(), loc);
    rewriter.setInsertionPointToEnd(block);

    // Dimensions in the result that aren't offset dimensions are called batch.
    SmallVector<int64_t> batchDims;
    for (int dim = 0; dim < resultRank; ++dim)
      if (!llvm::is_contained(offsetDims, dim)) batchDims.push_back(dim);

    // Same as with the constants. Creating these all up front is easier than
    // potentially getting duplicates later.
    SmallVector<Value> linalgIndices;
    linalgIndices.reserve(resultRank);
    for (int i = 0; i < resultRank; ++i)
      linalgIndices.push_back(rewriter.create<linalg::IndexOp>(loc, i));

    // Now the complicated part. For a given output dimension we build up an
    // index into the input. It's composed of two parts: the index coming from
    // start_indices, and the offset from that index along the offset
    // dimensions. Everything includes dimension shuffling and remapping as well
    // because of the way gather is defined to allow for any-layout input by
    // adding more attributes.

    // The base gather index (`G` in the documentation) points to a place in
    // start_indices along the batch dimensions.
    SmallVector<Value> gatherIndex;
    for (int64_t dim : batchDims) gatherIndex.push_back(linalgIndices[dim]);

    SmallVector<Value> indexFromStartIndices;
    for (int i = 0, e = static_cast<int>(startIndexMap.size()); i < e; ++i) {
      // The index along the index_vector dimension of start_indices varies.
      // Basically indexFromStartIndices indexes into a "row" along
      // index_vector_dim, where the row is selected by the current output
      // index.
      // But if index_vector_dim is equal to start_indices.rank, then
      // start_indices gets a trailing 1 dimension added. So the row we're
      // extracting always has length 1 and the index into it is always 0, so we
      // just use the gather index directly.
      SmallVector<Value> gCombine(gatherIndex);
      if (indexVectorDim != startIndicesType.getRank()) {
        assert(indexVectorDim >= 0 &&
               static_cast<size_t>(indexVectorDim) <= gCombine.size());
        gCombine.insert(gCombine.begin() + indexVectorDim, constants[i]);
      }

      indexFromStartIndices.push_back(extractAsIndex(startIndices, gCombine));
    }

    // But then start indices are shuffled by the start index map. To make a
    // full index into the operand, all missing indices are zeroes.
    SmallVector<Value> remappedIndexFromIndices(operandRank, constants[0]);
    for (const auto& it : llvm::enumerate(startIndexMap))
      remappedIndexFromIndices[it.value()] = indexFromStartIndices[it.index()];

    // Now we construct the index based on the offset. First we need to remap
    // the offset dimensions by dropping the collapsed indices.
    SmallVector<unsigned> remappedOffsetDims;
    for (int i = 0; i < operandRank; ++i)
      if (!llvm::is_contained(collapsedSliceDims, i))
        remappedOffsetDims.push_back(i);

    assert(remappedOffsetDims.size() == offsetDims.size());

    // For the (remapped) offset dimensions, the index is the current index in
    // the output. As before this is expanded to a full index into the operand
    // by using zeroes for the missing indices.
    SmallVector<Value> indexFromOffset(operandRank, constants[0]);
    for (size_t k = 0; k < offsetDims.size(); ++k)
      indexFromOffset[remappedOffsetDims[k]] = linalgIndices[offsetDims[k]];

    // Now we add together our two indices to get the final index into the
    // operand.
    SmallVector<Value> combinedIndex;
    combinedIndex.reserve(operandRank);
    for (int i = 0; i < operandRank; ++i)
      combinedIndex.push_back(rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), remappedIndexFromIndices[i],
          indexFromOffset[i]));

    Value element =
        rewriter.create<tensor::ExtractOp>(loc, operand, combinedIndex);
    rewriter.create<linalg::YieldOp>(loc, element);

    rewriter.replaceOp(gatherOp, linalgOp.getResults());

    return success();
  }
};
/// Lowers a tensor_scatter_nd_update-like mhlo.scatter (one whose body simply
/// returns the update value) to a linalg.generic.
///
/// The generic op iterates over the product of the operand space and the
/// scatter-indices space (minus the index-vector dimension). In each iteration
/// it compares the loop index of the scattered operand dimension against the
/// current scatter index and selects either the update value or the value
/// already in the output.
///
/// The pattern only matches when:
///   * the scatter body is a single op returning the body's block argument #1,
///   * operand and scatter_indices are ranked tensors,
///   * index_vector_dim is the last dimension of scatter_indices and has
///     size 1 (index depth 1),
///   * scatter_dims_to_operand_dims has exactly one entry.
struct ScatterUpdateConversion : public OpConversionPattern<mhlo::ScatterOp> {
  using OpConversionPattern<mhlo::ScatterOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ScatterOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    // Check if it is a tensor_scatter_nd_update-like op.
    Block& body = op.getRegion().front();
    auto& body_ops = body.getOperations();
    if (body_ops.size() != 1) return failure();
    Operation& terminator = body_ops.front();
    if (terminator.getNumOperands() == 0) return failure();
    auto ret_arg = terminator.getOperand(0).dyn_cast<BlockArgument>();
    // The returned value must be the second argument of the scatter body
    // itself; a block argument captured from an enclosing region that merely
    // happens to have number 1 does not count.
    if (!ret_arg || ret_arg.getOwner() != &body || ret_arg.getArgNumber() != 1)
      return failure();

    auto operand_ty = adaptor.operand().getType().dyn_cast<RankedTensorType>();
    auto indices_ty =
        adaptor.scatter_indices().getType().dyn_cast<RankedTensorType>();
    if (!operand_ty || !indices_ty) return failure();

    // Linalg operations put all the computation to the innermost loop. Since we
    // also iterate over scatter_indices() with some loops, we can only check
    // one scatter index in one iteration. If there are multiple indices (ie,
    // the index depth is greater than 1), we don't have a way to keep the
    // comparison state. E.g., if the index_depth is 2, like indices = [[0, 1]],
    // we should use the update value only if (i == 0 and j == 1). However, we
    // can not get both indices in one iteration unless we pack them together.
    //
    // The position of index_vector_dim is validated first: it may legally be
    // equal to the rank of scatter_indices (implicit trailing dimension), in
    // which case querying its size would be out of range.
    int64_t index_vector_dim =
        op.scatter_dimension_numbers().getIndexVectorDim();
    if (index_vector_dim != indices_ty.getRank() - 1) {
      return rewriter.notifyMatchFailure(
          op, "require index_vector_dim to be the last dim");
    }
    if (indices_ty.getDimSize(index_vector_dim) != 1)
      return rewriter.notifyMatchFailure(op, "require index depth to be 1");

    auto scatter_dims_to_operand_dims =
        op.scatter_dimension_numbers().getScatterDimsToOperandDims();
    // Element 0 is read below, so bail out instead of relying on an assert
    // that disappears in release builds.
    if (scatter_dims_to_operand_dims.size() != 1) {
      return rewriter.notifyMatchFailure(
          op, "require exactly one scatter_dims_to_operand_dims entry");
    }

    // One of indices dims is index depth vector.
    int64_t nloops = operand_ty.getRank() + indices_ty.getRank() - 1;
    SmallVector<AffineMap, 3> indexing_maps;
    {
      // Operand: the leading operand-rank loops, in order.
      SmallVector<AffineExpr> exprs;
      for (int64_t i = 0, e = operand_ty.getRank(); i < e; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    {
      // Scatter indices: the remaining loops followed by constant 0 for the
      // index-vector dimension.
      SmallVector<AffineExpr> exprs;
      for (int64_t i = operand_ty.getRank(); i < nloops; ++i)
        exprs.push_back(rewriter.getAffineDimExpr(i));
      // The index depth is 1.
      exprs.push_back(rewriter.getAffineConstantExpr(0));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));

      // Updates: the same scatter loops (without the constant 0) followed by
      // the loops named by update_window_dims.
      exprs.pop_back();
      auto update_window_dims =
          op.scatter_dimension_numbers().getUpdateWindowDims();
      for (auto d : update_window_dims)
        exprs.push_back(rewriter.getAffineDimExpr(d));
      indexing_maps.push_back(AffineMap::get(nloops, /*symbolCount=*/0, exprs,
                                             rewriter.getContext()));
    }
    // The output is indexed exactly like the operand.
    indexing_maps.push_back(indexing_maps.front());

    auto result_ty = this->typeConverter->convertType(op.getResult().getType())
                         .cast<ShapedType>();

    // Do not need init_tensor because we'd like to initialize the output as
    // operand.
    auto linalg_op = rewriter.create<linalg::GenericOp>(
        op.getLoc(), /*resultTensors=*/ArrayRef<Type>{result_ty},
        /*inputs=*/
        ValueRange{adaptor.operand(), adaptor.scatter_indices(),
                   adaptor.updates()},
        /*outputs=*/adaptor.operand(), indexing_maps,
        GetNParallelLoopsAttrs(nloops),
        [&](OpBuilder& b, Location loc, ValueRange args) {
          // args: [0] operand, [1] scatter index, [2] update, [3] output.
          Value cmp_idx =
              b.create<linalg::IndexOp>(loc, scatter_dims_to_operand_dims[0]);
          Value idx =
              b.create<arith::IndexCastOp>(loc, b.getIndexType(), args[1]);
          Value pred = b.create<arith::CmpIOp>(
              loc, b.getI1Type(), arith::CmpIPredicate::eq, cmp_idx, idx);
          // Use the output arg, so some update values won't be init value
          // again.
          Value res = b.create<arith::SelectOp>(loc, args[2].getType(), pred,
                                                args[2], args[3]);
          b.create<linalg::YieldOp>(loc, res);
        },
        PruneAttributeList(op));
    rewriter.replaceOp(op, linalg_op.getResults());
    return success();
  }
};
/// Generic lowering of mhlo.dot_general to a linalg.generic.
///
/// The loop nest has one parallel loop per result dimension followed by one
/// reduction loop per contracting dimension. The result of dot_general is laid
/// out as [batch dims..., lhs free dims..., rhs free dims...], where "free"
/// means neither batching nor contracting, so the operand indexing maps are:
///   * batching dim #k        -> loop k
///   * contracting dim #k     -> loop (result rank + k)
///   * j-th free dim of lhs   -> loop (#batch + j)
///   * j-th free dim of rhs   -> loop (#batch + #lhs free + j)
/// Free dims are numbered by their order of appearance in the operand, not by
/// their absolute position, so operands whose free dims follow batching or
/// contracting dims are mapped to the right loops.
class DotGeneralOpConversion : public OpConversionPattern<mhlo::DotGeneralOp> {
 public:
  using OpConversionPattern<mhlo::DotGeneralOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::DotGeneralOp op, OpAdaptor adaptor,
      ConversionPatternRewriter& rewriter) const final {
    if (!VerifyHloOpBufferOrTensorSemantics(op)) {
      return failure();
    }

    // Get various dimension iterator information
    mhlo::DotDimensionNumbersAttr dim_numbers = op.dot_dimension_numbers();
    auto lhs_batching_dims = dim_numbers.getLhsBatchingDimensions();
    auto rhs_batching_dims = dim_numbers.getRhsBatchingDimensions();
    auto lhs_contracting_dims = dim_numbers.getLhsContractingDimensions();
    auto rhs_contracting_dims = dim_numbers.getRhsContractingDimensions();

    // Get shape information and initialize output
    assert(lhs_contracting_dims.size() == rhs_contracting_dims.size() &&
           "number of contracting dims must be equal");
    const int64_t num_contracting =
        static_cast<int64_t>(lhs_contracting_dims.size());
    const int64_t num_lhs_batching =
        static_cast<int64_t>(lhs_batching_dims.size());
    const int64_t num_rhs_batching =
        static_cast<int64_t>(rhs_batching_dims.size());

    auto output_type = op.getType().cast<ShapedType>();
    const int64_t target_rank = output_type.getRank();
    const int64_t total_loop_count = num_contracting + target_rank;

    const int64_t lhs_rank =
        adaptor.lhs().getType().cast<ShapedType>().getRank();
    const int64_t lhs_extra_dims =
        lhs_rank - num_lhs_batching - num_contracting;
    const int64_t rhs_rank =
        adaptor.rhs().getType().cast<ShapedType>().getRank();

    Location loc = op.getLoc();
    auto output_el_type = output_type.getElementType();
    SmallVector<Value, 8> dyn_shape = GetDotGeneralOpInitTensorDynSizes(
        rewriter, loc, adaptor.lhs(), adaptor.rhs(), output_type);

    // The accumulator starts out as an all-zero tensor of the result shape.
    auto zero_attr = rewriter.getZeroAttr(output_el_type);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zero_attr);
    auto init_tensor = GetInitTensor(rewriter, loc, output_type, dyn_shape);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, zero, init_tensor).getResult(0);

    // Builds the indexing map of one operand. `first_free_loop` is the loop
    // that the operand's first free dimension maps to; subsequent free
    // dimensions take the following loops in order of appearance.
    auto build_operand_map = [&](int64_t rank, ArrayRef<int64_t> batching_dims,
                                 ArrayRef<int64_t> contracting_dims,
                                 int64_t first_free_loop) -> AffineMap {
      llvm::SmallVector<AffineExpr> exprs(rank,
                                          rewriter.getAffineConstantExpr(0));
      llvm::BitVector assigned_dims(rank, false);
      for (const auto& en : llvm::enumerate(batching_dims)) {
        exprs[en.value()] = rewriter.getAffineDimExpr(en.index());
        assigned_dims.set(en.value());
      }
      for (const auto& en : llvm::enumerate(contracting_dims)) {
        exprs[en.value()] =
            rewriter.getAffineDimExpr(en.index() + target_rank);
        assigned_dims.set(en.value());
      }
      int64_t next_free_loop = first_free_loop;
      for (int64_t i = 0; i < rank; ++i) {
        if (!assigned_dims[i]) {
          exprs[i] = rewriter.getAffineDimExpr(next_free_loop++);
        }
      }
      return AffineMap::get(/*dimCount=*/total_loop_count,
                            /*symbolCount=*/0, exprs, op->getContext());
    };

    SmallVector<AffineMap, 3> indexing_maps;
    // Get LHS map
    indexing_maps.push_back(build_operand_map(
        lhs_rank, lhs_batching_dims, lhs_contracting_dims, num_lhs_batching));
    // Get RHS map
    indexing_maps.push_back(
        build_operand_map(rhs_rank, rhs_batching_dims, rhs_contracting_dims,
                          num_rhs_batching + lhs_extra_dims));
    // Output map: identity over the leading (parallel) loops.
    {
      SmallVector<AffineExpr, 4> dim_exprs;
      dim_exprs.reserve(target_rank);
      for (int64_t i = 0; i < target_rank; ++i)
        dim_exprs.push_back(rewriter.getAffineDimExpr(i));
      indexing_maps.push_back(AffineMap::get(/*dimCount=*/total_loop_count,
                                             /*symbolCount=*/0, dim_exprs,
                                             op.getContext()));
    }

    Operation* linalg_op = rewriter.create<linalg::GenericOp>(
        loc, /*resultTensorTypes=*/TypeRange{op.getType()},
        /*inputs=*/ValueRange{adaptor.lhs(), adaptor.rhs()},
        /*outputBuffers=*/ValueRange{zero_tensor}, indexing_maps,
        GetParallelAndReductionIterators(
            /*nLoops=*/total_loop_count,
            /*nReduction=*/num_contracting),
        [&](OpBuilder& b, Location loc, ValueRange args) {
          // acc += lhs * rhs
          mlir::ArithBuilder ab(b, loc);
          mlir::Value mul = ab.mul(args[0], args[1]);
          mlir::Value add = ab.add(mul, args[2]);
          b.create<mlir::linalg::YieldOp>(loc, add);
        },
        PruneAttributeList(op));

    rewriter.replaceOp(op, linalg_op->getResults());
    return success();
  }
};
/// Pass that runs a partial dialect conversion over the operation it is
/// scheduled on, using the patterns registered by
/// populateHLOToLinalgConversionPattern and a RemoveSignTypeConverter.
struct HloLegalizeToLinalgPass
    : public mhlo::HloLegalizeToLinalgPassBase<HloLegalizeToLinalgPass> {
  /// Registers the dialects this pass declares as dependencies.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<linalg::LinalgDialect, scf::SCFDialect,
                    complex::ComplexDialect, math::MathDialect,
                    memref::MemRefDialect, shape::ShapeDialect>();
  }

  /// Applies the conversion and signals pass failure if it does not succeed.
  void runOnOperation() override {
    MLIRContext& context = getContext();
    RewritePatternSet conversion_patterns(&context);

    // Ops from these dialects (and unrealized casts) are left untouched by
    // the conversion; the conversion is partial, so ops not covered here are
    // not required to be converted.
    ConversionTarget legality(context);
    legality.addLegalDialect<arith::ArithmeticDialect, complex::ComplexDialect,
                             linalg::LinalgDialect, math::MathDialect,
                             tensor::TensorDialect,
                             sparse_tensor::SparseTensorDialect,
                             scf::SCFDialect, shape::ShapeDialect>();
    legality.addLegalOp<UnrealizedConversionCastOp>();

    mhlo::RemoveSignTypeConverter sign_remover;
    auto root = getOperation();
    mhlo::populateHLOToLinalgConversionPattern(&context, sign_remover,
                                               &conversion_patterns);

    if (succeeded(applyPartialConversion(root, legality,
                                         std::move(conversion_patterns)))) {
      return;
    }
    signalPassFailure();
  }
};
} // namespace
namespace mhlo {
/// Adds the HLO-to-Linalg conversion patterns to `patterns`.
///
/// `context` is the MLIRContext handed to every pattern, `type_converter` is
/// handed to every pattern except the reduce-region ones, and `patterns` is
/// the set that receives them. Patterns are added in four groups:
///   1. the bulk of the op conversions, with no explicit benefit,
///   2. the specialized dot / batch-matmul conversions with benefit 2,
///   3. the generic DotGeneralOpConversion with benefit 1,
///   4. the reduce-region conversions with benefit 1000, constructed from the
///      context only.
void populateHLOToLinalgConversionPattern(MLIRContext* context,
TypeConverter& type_converter,
RewritePatternSet* patterns) {
// Group 1: general op conversions (no explicit PatternBenefit passed).
// clang-format off
patterns->add<
BroadcastConverter<mhlo::BroadcastOp>, ConcatenateConverter,
ConstConverterTensor, HloDynamicBroadcastInDimConverter,
HloBroadcastInDimConverter, IotaConverter<mhlo::IotaOp>,
EinsumToLinalgConverter,
IotaConverter<mhlo::DynamicIotaOp>,
PointwiseToLinalgConverter<mhlo::AbsOp>,
PointwiseToLinalgConverter<mhlo::AddOp>,
PointwiseToLinalgConverter<mhlo::AndOp>,
PointwiseToLinalgConverter<mhlo::Atan2Op>,
PointwiseToLinalgConverter<mhlo::BitcastConvertOp>,
PointwiseToLinalgConverter<mhlo::CbrtOp>,
PointwiseToLinalgConverter<mhlo::CeilOp>,
PointwiseToLinalgConverter<mhlo::ClampOp>,
PointwiseToLinalgConverter<mhlo::CompareOp>,
PointwiseToLinalgConverter<mhlo::ComplexOp>,
PointwiseToLinalgConverter<mhlo::ConvertOp>,
PointwiseToLinalgConverter<mhlo::CopyOp>,
PointwiseToLinalgConverter<mhlo::CosOp>,
PointwiseToLinalgConverter<mhlo::DivOp>,
PointwiseToLinalgConverter<mhlo::ExpOp>,
PointwiseToLinalgConverter<mhlo::Expm1Op>,
PointwiseToLinalgConverter<mhlo::FloorOp>,
PointwiseToLinalgConverter<mhlo::ImagOp>,
PointwiseToLinalgConverter<mhlo::IsFiniteOp>,
PointwiseToLinalgConverter<mhlo::LogOp>,
PointwiseToLinalgConverter<mhlo::LogisticOp>,
PointwiseToLinalgConverter<mhlo::Log1pOp>,
PointwiseToLinalgConverter<mhlo::MaxOp>,
PointwiseToLinalgConverter<mhlo::MinOp>,
PointwiseToLinalgConverter<mhlo::MulOp>,
PointwiseToLinalgConverter<mhlo::NegOp>,
PointwiseToLinalgConverter<mhlo::NotOp>,
PointwiseToLinalgConverter<mhlo::OrOp>,
PointwiseToLinalgConverter<mhlo::PopulationCountOp>,
PointwiseToLinalgConverter<mhlo::PowOp>,
PointwiseToLinalgConverter<mhlo::RealOp>,
PointwiseToLinalgConverter<mhlo::RemOp>,
PointwiseToLinalgConverter<mhlo::RoundOp>,
PointwiseToLinalgConverter<mhlo::RsqrtOp>,
PointwiseToLinalgConverter<mhlo::SelectOp>,
PointwiseToLinalgConverter<mhlo::ShiftLeftOp>,
PointwiseToLinalgConverter<mhlo::ShiftRightArithmeticOp>,
PointwiseToLinalgConverter<mhlo::ShiftRightLogicalOp>,
PointwiseToLinalgConverter<mhlo::SignOp>,
PointwiseToLinalgConverter<mhlo::SinOp>,
PointwiseToLinalgConverter<mhlo::SqrtOp>,
PointwiseToLinalgConverter<mhlo::SubOp>,
PointwiseToLinalgConverter<mhlo::TanhOp>,
PointwiseToLinalgConverter<mhlo::XorOp>,
RealDynamicSliceConverter,
ReshapeOpConverter,
ReverseConverter,
SliceConverter,
DynamicSliceConverter,
DynamicUpdateSliceConverter,
TransposeConverter<mhlo::TransposeOp>,
NormalConvOpConversion,
DepthwiseConvOpConversion,
ReduceConversion,
ReduceWindowOpOnTensorsGenericConversion,
ReduceWindowOpConversion,
RngUniformConversion,
ScatterUpdateConversion,
GatherConversion,
TorchIndexSelectOpConversion,
PadOpConversion>(type_converter, context);
// Group 2: specialized dot lowerings to named linalg ops. They carry a higher
// benefit (2) than the generic DotGeneralOpConversion registered below (1).
patterns->add<
DotOpConversion<DotOperationType::kMatrixMatrix, linalg::MatmulOp>,
DotOpConversion<DotOperationType::kMatrixVector, linalg::MatvecOp>,
DotOpConversion<DotOperationType::kVectorMatrix, linalg::VecmatOp>,
DotOpConversion<DotOperationType::kVectorDot, linalg::DotOp>,
DotGeneralBatchMatMulOpConversion>(type_converter, context,
PatternBenefit(2));
// clang-format on
// Group 3: generic dot_general lowering, registered with the lower benefit 1.
patterns->add<DotGeneralOpConversion>(type_converter, context,
PatternBenefit(1));
// Group 4: conversions for ops inside reduce regions. These are constructed
// without the type converter and with benefit 1000.
patterns->add<ReduceRegionXLAOpConversion<mhlo::AddOp>,
ReduceRegionXLAOpConversion<mhlo::AndOp>,
ReduceRegionXLAOpConversion<mhlo::CompareOp>,
ReduceRegionXLAOpConversion<mhlo::MaxOp>,
ReduceRegionXLAOpConversion<mhlo::MinOp>,
ReduceRegionXLAOpConversion<mhlo::MulOp>,
ReduceRegionXLAOpConversion<mhlo::OrOp>,
ReduceRegionXLAOpConversion<mhlo::SelectOp>,
ReduceRegionReturnOpConversion>(context, PatternBenefit(1000));
}
/// Creates a new instance of the HLO-to-Linalg legalization pass.
std::unique_ptr<OperationPass<func::FuncOp>> createLegalizeHloToLinalgPass() {
  std::unique_ptr<OperationPass<func::FuncOp>> pass =
      std::make_unique<HloLegalizeToLinalgPass>();
  return pass;
}
/// Creates the type converter used by the HLO-to-Linalg lowering: a freshly
/// constructed RemoveSignTypeConverter, returned through the TypeConverter
/// base interface.
std::unique_ptr<TypeConverter> createHloToLinalgSignedIntegerConverter() {
  std::unique_ptr<TypeConverter> converter =
      std::make_unique<RemoveSignTypeConverter>();
  return converter;
}
} // namespace mhlo
} // namespace mlir