| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h" |
| |
| #include <climits> |
| #include <cstdint> |
| #include <utility> |
| |
| #include "absl/memory/memory.h" |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "llvm/Support/Casting.h" |
| #include "llvm/Support/Debug.h" |
| #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project |
| #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| #include "mlir/IR/Attributes.h" // from @llvm-project |
| #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
| #include "mlir/IR/OpImplementation.h" // from @llvm-project |
| #include "mlir/IR/PatternMatch.h" // from @llvm-project |
| #include "mlir/Pass/Pass.h" // from @llvm-project |
| #include "mlir/Support/LLVM.h" // from @llvm-project |
| #include "mlir/Support/LogicalResult.h" // from @llvm-project |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h" |
| #include "tensorflow/core/util/matmul_bcast.h" |
| |
| namespace mlir { |
| namespace TF { |
| |
| namespace { |
| |
| template <typename BatchMatMulOpType> |
| class ConvertTFBatchMatMulOp : public OpRewritePattern<BatchMatMulOpType> { |
| using OpRewritePattern<BatchMatMulOpType>::OpRewritePattern; |
| |
| static TF::ReshapeOp createReshapeOp(Value value, ArrayRef<int64_t> shape, |
| Type element_type, Location loc, |
| PatternRewriter& rewriter); |
| |
| static std::vector<Value> sliceInput(Value value, int batch_size, |
| Location loc, PatternRewriter& rewriter); |
| |
| LogicalResult matchAndRewrite(BatchMatMulOpType op, |
| PatternRewriter& rewriter) const override; |
| }; |
| |
| // Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out |
| // of the inputs, matmul them individually, then stack them all back together at |
| // the end. |
| struct UnrollBatchMatMulPass |
| : public UnrollBatchMatMulPassBase<UnrollBatchMatMulPass> { |
| void runOnOperation() override; |
| }; |
| |
| void UnrollBatchMatMulPass::runOnOperation() { |
| RewritePatternSet patterns(&getContext()); |
| auto func = getOperation(); |
| PopulateUnrollTfBatchMatMul(&getContext(), patterns); |
| (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); |
| } |
| |
| } // namespace |
| |
| template <typename BatchMatMulOpType> |
| TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp( |
| Value value, ArrayRef<int64_t> shape, Type element_type, Location loc, |
| PatternRewriter& rewriter) { |
| int64_t shape_rank = shape.size(); |
| auto shape_spec_type = |
| RankedTensorType::get({shape_rank}, rewriter.getIntegerType(64)); |
| Type resultType = RankedTensorType::get(shape, element_type); |
| auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape); |
| auto shape_tensor = |
| rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr); |
| return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value, |
| /*shape=*/shape_tensor); |
| } |
| |
| template <typename BatchMatMulOpType> |
| std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput( |
| Value value, int batch_size, Location loc, PatternRewriter& rewriter) { |
| RankedTensorType tensorType = value.getType().cast<RankedTensorType>(); |
| Type element_type = tensorType.getElementType(); |
| |
| int rank = tensorType.getShape().size(); |
| int num_rows = tensorType.getShape()[rank - 2]; |
| int num_cols = tensorType.getShape()[rank - 1]; |
| |
| std::vector<Value> sliced; |
| |
| if (batch_size == 1) { |
| // Batch size is 1, no splitting is required |
| // Squeeze the batch dimension, i.e. reshape |
| // [1, num_rows, num_cols] -> [num_rows, num_cols] |
| auto reshape_op = createReshapeOp(value, {num_rows, num_cols}, element_type, |
| loc, rewriter); |
| sliced.emplace_back(reshape_op.output()); |
| } else { |
| // Reshape to rank-3 tensor with first dimension as the batch size. |
| auto reshape_op = createReshapeOp(value, {batch_size, num_rows, num_cols}, |
| element_type, loc, rewriter); |
| |
| // Create a constant op for the split axis (=0) |
| auto split_dimension_type = |
| RankedTensorType::get({}, rewriter.getIntegerType(32)); |
| auto split_dimension_attr = DenseElementsAttr::get(split_dimension_type, 0); |
| auto split_dimension_op = rewriter.create<TF::ConstOp>( |
| loc, split_dimension_type, split_dimension_attr); |
| |
| // Split along each batch. |
| SmallVector<int64_t, 3> slice_size = {1, num_rows, num_cols}; |
| Type slice_result_type = RankedTensorType::get(slice_size, element_type); |
| llvm::SmallVector<Type, 4> output_types(batch_size, slice_result_type); |
| auto split_op = rewriter.create<TF::SplitOp>( |
| loc, output_types, split_dimension_op.output(), reshape_op.output()); |
| |
| // Squeeze each batch, i.e. reshape |
| // [1, num_rows, num_cols] -> [num_rows, num_cols] |
| for (const auto& split_value : split_op.output()) { |
| auto reshape_op = createReshapeOp(split_value, {num_rows, num_cols}, |
| element_type, loc, rewriter); |
| |
| sliced.emplace_back(reshape_op.output()); |
| } |
| } |
| return sliced; |
| } |
| |
| template <typename BatchMatMulOpType> |
| LogicalResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite( |
| BatchMatMulOpType op, PatternRewriter& rewriter) const { |
| Value input_lhs = op.x(); |
| Value input_rhs = op.y(); |
| |
| if (!input_lhs.getType().isa<RankedTensorType>()) { |
| // LHS must be a ranked tensor type |
| return failure(); |
| } |
| if (!input_rhs.getType().isa<RankedTensorType>()) { |
| // RHS must be a ranked tensor type |
| return failure(); |
| } |
| |
| auto lhs_type = input_lhs.getType().cast<RankedTensorType>(); |
| auto rhs_type = input_rhs.getType().cast<RankedTensorType>(); |
| |
| // Skip int8 x int8 => int32. |
| if (lhs_type.getElementType().isInteger(8) && |
| rhs_type.getElementType().isInteger(8)) { |
| return rewriter.notifyMatchFailure(op, |
| "skip unrolling for int8 BatchMatMulV3"); |
| } |
| |
| auto element_type = lhs_type.getElementType(); |
| |
| if (element_type != rhs_type.getElementType()) { |
| // The element type of LHS must be the same with element type of RHS |
| return failure(); |
| } |
| |
| std::vector<int64_t> lhs_shape = lhs_type.getShape(); |
| std::vector<int64_t> rhs_shape = rhs_type.getShape(); |
| |
| Location loc = op.getLoc(); |
| |
| // Ensure that input ranks are at least 2. |
| const int lhs_dims = lhs_shape.size(); |
| const int rhs_dims = rhs_shape.size(); |
| if (lhs_dims < 2 || rhs_dims < 2) { |
| // Both inputs must have rank >= 2 |
| return failure(); |
| } |
| |
| // Replace the last 2 dimensions of LHS and RHS if necessary. |
| // The actual transpose is done by MatMulOp. |
| if (op.adj_x()) { |
| std::swap(lhs_shape[lhs_dims - 1], lhs_shape[lhs_dims - 2]); |
| } |
| if (op.adj_y()) { |
| std::swap(rhs_shape[rhs_dims - 1], rhs_shape[rhs_dims - 2]); |
| } |
| |
| const int rows = lhs_shape[lhs_dims - 2]; |
| const int cols = rhs_shape[rhs_dims - 1]; |
| |
| if (lhs_shape[lhs_dims - 1] != rhs_shape[rhs_dims - 2]) { |
| // Input dimensions must be compatible for multiplication. |
| return failure(); |
| } |
| |
| const auto matmul_type = RankedTensorType::get({rows, cols}, element_type); |
| |
| if (lhs_dims == 2 && rhs_dims == 2) { |
| // When both inputs are matrices, just replace the op with a matmul op. |
| rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, matmul_type, |
| /*a=*/input_lhs, |
| /*b=*/input_rhs, |
| /*transpose_a=*/op.adj_x(), |
| /*transpose_b=*/op.adj_y()); |
| return success(); |
| } |
| |
| // Input dimensions must be defined. MatMulBCast does not support partial |
| // shapes. |
| for (auto dim : lhs_shape) { |
| if (dim == -1) { |
| return failure(); |
| } |
| } |
| for (auto dim : rhs_shape) { |
| if (dim == -1) { |
| return failure(); |
| } |
| } |
| // Ensure that batch shapes are broadcastable. |
| tensorflow::MatMulBCast bcast( |
| absl::InlinedVector<int64_t, 4>(lhs_shape.begin(), lhs_shape.end()), |
| absl::InlinedVector<int64_t, 4>(rhs_shape.begin(), rhs_shape.end())); |
| |
| if (!bcast.IsValid()) { |
| // Input batch dimensions must be broadcastable |
| return failure(); |
| } |
| |
| // Compute slices for each batch in the LHS and RHS. |
| std::vector<Value> sliced_lhs = |
| sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter); |
| std::vector<Value> sliced_rhs = |
| sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter); |
| |
| // Compute (single batch) MatMul for each output batch. |
| std::vector<Value> matmuls; |
| matmuls.reserve(bcast.output_batch_size()); |
| for (int batch_idx : llvm::seq<int>(0, bcast.output_batch_size())) { |
| int lhs_batch_idx, rhs_batch_idx; |
| if (bcast.IsBroadcastingRequired()) { |
| lhs_batch_idx = bcast.x_batch_indices()[batch_idx]; |
| rhs_batch_idx = bcast.y_batch_indices()[batch_idx]; |
| } else { |
| lhs_batch_idx = batch_idx; |
| rhs_batch_idx = batch_idx; |
| } |
| auto matmul = rewriter.create<TF::MatMulOp>(loc, matmul_type, |
| /*a=*/sliced_lhs[lhs_batch_idx], |
| /*b=*/sliced_rhs[rhs_batch_idx], |
| /*transpose_a=*/op.adj_x(), |
| /*transpose_b=*/op.adj_y()); |
| matmuls.emplace_back(matmul.product()); |
| } |
| |
| // Combine the result of each individual MatMul into a rank-3 tensor. |
| Type packed_type = RankedTensorType::get( |
| {bcast.output_batch_size(), rows, cols}, element_type); |
| const auto axis = rewriter.getI64IntegerAttr(0); |
| auto pack_op = |
| rewriter.create<TF::PackOp>(loc, packed_type, /*values=*/matmuls, axis); |
| |
| // Reshape the rank-3 tensor into the correct output shape. |
| const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes(); |
| std::vector<int64_t> result_shape(result_batch_shape.begin(), |
| result_batch_shape.end()); |
| result_shape.push_back(rows); |
| result_shape.push_back(cols); |
| |
| auto reshape_op = createReshapeOp(pack_op.output(), result_shape, |
| element_type, loc, rewriter); |
| rewriter.replaceOp(op, reshape_op.output()); |
| return success(); |
| } |
| |
| std::unique_ptr<OperationPass<func::FuncOp>> CreateUnrollBatchMatMulPassPass() { |
| return std::make_unique<UnrollBatchMatMulPass>(); |
| } |
| |
| } // namespace TF |
| } // namespace mlir |
| |
| void mlir::TF::PopulateUnrollTfBatchMatMul(MLIRContext* context, |
| RewritePatternSet& patterns) { |
| patterns.add<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>, |
| ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>, |
| ConvertTFBatchMatMulOp<TF::BatchMatMulV3Op>>(context); |
| } |