blob: 16a73b1e3569755a6186b594f7b721b838e6c933 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
namespace mhlo {
namespace {
/// Rewrites an `mhlo.gather` that selects whole slices along dimension 0 of
/// its operand into an `mhlo.torch_index_select` (dim = 0, batch_dims = 0)
/// followed by an `mhlo.reshape` to the gather's original result type.
///
/// The pattern only matches when all of the following hold:
///   * operand, start_indices and result are ranked;
///   * index_vector_dim == max(0, rank(start_indices) - 1);
///   * the trailing dimension of start_indices (if it has one) is 1;
///   * start_index_map == [0];
///   * offset_dims == [index_vector_dim, rank(result));
///   * slice_sizes has rank(operand) entries, slice_sizes[0] == 1 and
///     slice_sizes[i] == dim(operand, i) for every i > 0;
///   * collapsed_slice_dims == [0].
/// Otherwise it reports a match failure and leaves the IR unchanged.
struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
  using OpRewritePattern<GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    auto start_indices = gather.start_indices();
    auto start_indices_ty = start_indices.getType().cast<ShapedType>();
    if (!start_indices_ty.hasRank()) {
      return rewriter.notifyMatchFailure(gather, "unranked start_indices");
    }

    auto operand = gather.operand();
    auto operand_ty = operand.getType().cast<ShapedType>();
    if (!operand_ty.hasRank()) {
      return rewriter.notifyMatchFailure(gather, "unranked operand");
    }

    // Last dimension of start_indices; std::max clamps to 0 so that a rank-0
    // start_indices is accepted with index_vector_dim == 0.
    const int64_t index_vector_dim =
        std::max<int64_t>(0, start_indices_ty.getRank() - 1);

    // We can use torch_index_select if the last dimension represents the
    // gather indices.
    auto dimension_numbers = gather.dimension_numbers();
    if (dimension_numbers.getIndexVectorDim() != index_vector_dim) {
      return rewriter.notifyMatchFailure(
          gather, "index_vector_dim not last dimension of start_indices");
    }

    // Index select only works across a single dimension.
    if (!start_indices_ty.getShape().empty() &&
        start_indices_ty.getShape().back() != 1) {
      return rewriter.notifyMatchFailure(
          gather, "start_indices index vector dimension not 1");
    }

    // Only support the default case for start_index_map.
    if (dimension_numbers.getStartIndexMap().size() != 1 ||
        dimension_numbers.getStartIndexMap()[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "start_index_map != [0]");
    }

    auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return rewriter.notifyMatchFailure(gather, "unranked result");
    }

    // Offset dimensions should be the defaults, i.e. the trailing
    // [index_vector_dim, rank(result)) dimensions of the result. Compare in
    // signed arithmetic so a negative expected count simply fails to match.
    const int64_t expected_num_offset_dims =
        result_ty.getRank() - index_vector_dim;
    if (static_cast<int64_t>(dimension_numbers.getOffsetDims().size()) !=
        expected_num_offset_dims) {
      return rewriter.notifyMatchFailure(
          gather, "offset_dims.size not result rank minus index_vector_dim");
    }
    for (const auto &it : llvm::enumerate(dimension_numbers.getOffsetDims())) {
      if (static_cast<int64_t>(it.index()) + index_vector_dim != it.value()) {
        return rewriter.notifyMatchFailure(
            gather, "offset_dims != [index_vector_dim, result.rank)");
      }
    }

    // Guard the per-dimension operand lookup below against a slice_sizes
    // attribute whose length differs from the operand rank.
    auto slice_sizes = gather.slice_sizes();
    if (slice_sizes.getNumElements() != operand_ty.getRank()) {
      return rewriter.notifyMatchFailure(
          gather, "slice_sizes size doesn't match operand rank");
    }
    for (const auto &it : llvm::enumerate(slice_sizes.getValues<APInt>())) {
      // First shape value must be 1.
      if (it.index() == 0) {
        if (it.value().getSExtValue() != 1) {
          return rewriter.notifyMatchFailure(gather, "slice_size[0] != 1");
        }
        continue;
      }

      // The gather needs to index the entire slice for each other dimension.
      if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
        return rewriter.notifyMatchFailure(
            gather, "slice_size doesn't match operand dimension");
      }
    }

    // Finish all matching before doing any rewrite preparation.
    if (dimension_numbers.getCollapsedSliceDims().size() != 1 ||
        dimension_numbers.getCollapsedSliceDims()[0] != 0) {
      return rewriter.notifyMatchFailure(gather, "collapsed_slice_dims != [0]");
    }

    // torch_index_select result shape: shape(start_indices) followed by the
    // operand shape without its leading (indexed) dimension.
    llvm::SmallVector<int64_t, 4> index_select_shape =
        llvm::to_vector<4>(start_indices_ty.getShape());
    auto trailing_operand_dims = operand_ty.getShape().drop_front();
    index_select_shape.append(trailing_operand_dims.begin(),
                              trailing_operand_dims.end());

    auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
        gather.getLoc(),
        RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
        operand, start_indices, /*dim=*/rewriter.getI64IntegerAttr(0),
        /*batch_dims=*/rewriter.getI64IntegerAttr(0));

    // Reshape to the gather's declared result type (this drops, e.g., the
    // size-1 index vector dimension kept by torch_index_select).
    rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
                                           torch_index_select);
    return success();
  }
};
/// Pass that greedily applies the patterns registered by
/// PopulateGatherToTorchIndexSelectPatterns, rewriting eligible `mhlo.gather`
/// ops into `mhlo.torch_index_select` + `mhlo.reshape`.
struct LegalizeGatherToTorchIndexSelectPass
    : public LegalizeGatherToTorchIndexSelectPassBase<
          LegalizeGatherToTorchIndexSelectPass> {
  /// Runs the greedy pattern rewrite driver over the current operation and
  /// signals pass failure when the driver reports failure (for example, when
  /// it does not converge).
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    PopulateGatherToTorchIndexSelectPatterns(context, &patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
/// Registers the gather -> torch_index_select rewrite
/// (GatherIsTorchIndexSelect) into `patterns`, constructed with `context`.
void PopulateGatherToTorchIndexSelectPatterns(mlir::MLIRContext *context,
                                              RewritePatternSet *patterns) {
  RewritePatternSet &pattern_set = *patterns;
  pattern_set.add<GatherIsTorchIndexSelect>(context);
}
/// Factory for the gather -> torch_index_select legalization pass; the
/// returned pass instance is owned by the caller.
std::unique_ptr<OperationPass<func::FuncOp>>
createLegalizeGatherToTorchIndexSelectPass() {
  auto pass = std::make_unique<LegalizeGatherToTorchIndexSelectPass>();
  return pass;
}
} // namespace mhlo
} // namespace mlir