| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // This file implements logic for lowering TensorFlow dialect to XLA dialect. |
| |
| #include <cstdint> |
| #include <numeric> |
| |
| #include "llvm/ADT/Optional.h" |
| #include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir |
| #include "mlir/IR/Attributes.h" // TF:local_config_mlir |
| #include "mlir/IR/Diagnostics.h" // TF:local_config_mlir |
| #include "mlir/IR/MLIRContext.h" // TF:local_config_mlir |
| #include "mlir/IR/Matchers.h" // TF:local_config_mlir |
| #include "mlir/IR/Module.h" // TF:local_config_mlir |
| #include "mlir/IR/Operation.h" // TF:local_config_mlir |
| #include "mlir/IR/PatternMatch.h" // TF:local_config_mlir |
| #include "mlir/IR/StandardTypes.h" // TF:local_config_mlir |
| #include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir |
| #include "mlir/IR/Types.h" // TF:local_config_mlir |
| #include "mlir/Pass/Pass.h" // TF:local_config_mlir |
| #include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir |
| #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
| #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" |
| #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" |
| #include "tensorflow/compiler/mlir/xla/transforms/passes.h" |
| |
| using namespace mlir; |
| |
| namespace { |
| struct LegalizeTF : public FunctionPass<LegalizeTF> { |
| /// Performs the lowering to XLA dialect. |
| void runOnFunction() override; |
| }; |
| } // end anonymous namespace |
| |
| std::unique_ptr<mlir::OpPassBase<mlir::FuncOp>> |
| mlir::xla_hlo::createLegalizeTFPass() { |
| return std::make_unique<LegalizeTF>(); |
| } |
| |
| /// Returns if the given TF data format string is the default format. |
| static bool isDefaultDataFormat(StringRef format) { return format == "NHWC"; } |
| |
| /// Returns the feature dimension for the given format and input type. |
| static size_t getFeatureDimension(StringAttr format, |
| RankedTensorType inputType) { |
| return isDefaultDataFormat(format.getValue()) ? inputType.getRank() - 1 : 1; |
| } |
| |
| // Returns 1D 64-bit dense elements attribute with the given values. |
| static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values, |
| Builder *builder) { |
| RankedTensorType ty = builder->getTensorType( |
| {static_cast<int64_t>(values.size())}, builder->getIntegerType(64)); |
| return DenseElementsAttr::get<int64_t>(ty, values) |
| .cast<DenseIntElementsAttr>(); |
| } |
| |
| static IntegerAttr GetHLOAxisFromTFAxis(ElementsAttr attr, int64_t rank, |
| Builder *b) { |
| SmallVector<uint64_t, 1> index(attr.getType().getRank(), 0); |
| int64_t axis = attr.getValue<IntegerAttr>(index).getInt(); |
| if (axis < 0) { |
| axis += rank; |
| } |
| return b->getI64IntegerAttr(axis); |
| } |
| |
| // If `value` is an IntegerAttr, returns the integer value for the HLO axis |
| // corresponding to the tensorflow axis. In particular, the tensorflow axis can |
| // be negative, in which case, the corresponding HLO axis is |
| // (axis + rank-of-the-tensor). |
| static llvm::Optional<int64_t> GetIntegerHLOAxisFromTFAxis(Value *value, |
| int64_t rank) { |
| DenseIntElementsAttr attrs; |
| if (!matchPattern(value, m_Constant(&attrs)) || |
| attrs.getType().getRank() != 0) { |
| return llvm::None; |
| } |
| int64_t axis = attrs.getValue<IntegerAttr>({}).getInt(); |
| return axis < 0 ? axis + rank : axis; |
| } |
| |
| /// Returns a `ConvertOp` that casts the elements to a i64 type while retaining |
| /// the shape of the input value. |
| static xla_hlo::ConvertOp CastElementsToI64(Location loc, Value *value, |
| PatternRewriter *rewriter) { |
| auto type = value->getType().cast<RankedTensorType>(); |
| assert(type && "CastElementsToI64 requires a shaped tensor as input."); |
| ArrayRef<int64_t> shape = type.getShape(); |
| auto i64_type = rewriter->getTensorType(shape, rewriter->getIntegerType(64)); |
| return rewriter->create<xla_hlo::ConvertOp>(loc, i64_type, value); |
| } |
| |
| // Returns minimum value for the given int or float element type. |
| static xla_hlo::ConstOp GetMinValueForType(Type ty, Location loc, |
| PatternRewriter *rewriter) { |
| RankedTensorType scalar_ty = rewriter->getTensorType({}, ty); |
| |
| DenseElementsAttr attr; |
| if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) { |
| APFloat neg_inf = |
| APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/true); |
| attr = DenseElementsAttr::get(scalar_ty, neg_inf); |
| } else { |
| auto int_ty = ty.cast<IntegerType>(); |
| APInt min_val = APInt::getSignedMinValue(int_ty.getWidth()); |
| attr = DenseElementsAttr::get(scalar_ty, min_val); |
| } |
| return rewriter->create<xla_hlo::ConstOp>(loc, attr); |
| } |
| |
| // Returns an integer constant for the given int or float element type. |
| static xla_hlo::ConstOp GetScalarForType(Type ty, Location loc, |
| int64_t raw_value, |
| PatternRewriter *rewriter) { |
| RankedTensorType scalar_ty = rewriter->getTensorType({}, ty); |
| |
| DenseElementsAttr attr; |
| if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) { |
| APFloat value(float_ty.getFloatSemantics(), raw_value); |
| attr = DenseElementsAttr::get(scalar_ty, value); |
| } else { |
| auto int_ty = ty.cast<IntegerType>(); |
| APInt value(int_ty.getWidth(), raw_value, true); |
| attr = DenseElementsAttr::get(scalar_ty, value); |
| } |
| return rewriter->create<xla_hlo::ConstOp>(loc, attr); |
| } |
| |
| // Builds body for reduce op by using the using the template binary op as the |
| // reducer op. |
| template <typename Op> |
| static void BuildReduceBody(Type element_type, Region *body, |
| OpBuilder *builder) { |
| OpBuilder::InsertionGuard guard(*builder); |
| Block *block = builder->createBlock(body); |
| |
| // Block arguments are scalars of the given element type. |
| Type type = builder->getTensorType(/*shape=*/{}, element_type); |
| block->addArguments({type, type}); |
| |
| Location loc = body->getLoc(); |
| auto reducer = builder->create<Op>(loc, type, block->getArgument(0), |
| block->getArgument(1), |
| /*broadcast_dimensions=*/nullptr); |
| builder->create<xla_hlo::ReturnOp>(loc, reducer.getResult()); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // BatchNorm op utilities. |
| //===----------------------------------------------------------------------===// |
| |
| static IntegerAttr getFeatureDimensionAttr(Builder &b, StringAttr format, |
| Value *input) { |
| return b.getI64IntegerAttr( |
| getFeatureDimension(format, input->getType().cast<RankedTensorType>())); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Bias op utilities. |
| //===----------------------------------------------------------------------===// |
| |
| /// Return a 1D DenseIntElementsAttr for the feature dimension of a BiasAdd. |
| static DenseIntElementsAttr getBiasFeatureDimension(Builder &b, |
| StringAttr format, |
| Value *input) { |
| auto inputType = input->getType().cast<RankedTensorType>(); |
| size_t featureDim = getFeatureDimension(format, inputType); |
| RankedTensorType type = b.getTensorType(1, b.getIntegerType(64)); |
| return DenseIntElementsAttr::get(type, featureDim) |
| .cast<DenseIntElementsAttr>(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Binary op utilities. |
| //===----------------------------------------------------------------------===// |
| |
| /// Get a constant splat for the given value type. |
| template <typename T> |
| static ElementsAttr getSplat(Builder &b, Value *val, T constant) { |
| auto valType = val->getType().cast<TensorType>(); |
| auto valElementType = valType.getElementType(); |
| |
| // Handle integer elements. |
| Attribute elementAttr; |
| if (valElementType.isa<IntegerType>()) |
| elementAttr = b.getIntegerAttr(valElementType, constant); |
| else if (valElementType.isa<FloatType>()) |
| elementAttr = b.getFloatAttr(valElementType, constant); |
| else |
| llvm_unreachable("unhandled element type"); |
| return DenseElementsAttr::get(valType, elementAttr); |
| } |
| |
| // Returns whether the two values are guaranteed to be broadcastable to the |
| // same shape, this broadcasts size 1 tensors up to any rank. Dynamic dimensions |
| // must be broadcasted with a size 1 tensor or another dynamic dimension. |
| // Returns false on rankless. |
| static bool AreBroadcastCompatible(Value *x, Value *y) { |
| auto x_rankless = x->getType().dyn_cast<RankedTensorType>(); |
| auto y_rankless = y->getType().dyn_cast<RankedTensorType>(); |
| if (!x_rankless || !y_rankless) { |
| return false; |
| } |
| |
| // Check that the shapes can be broadcasted. |
| auto shape_x = x_rankless.getShape(); |
| auto shape_y = y_rankless.getShape(); |
| |
| int rank_diff = shape_x.size() - shape_y.size(); |
| int offset_x = rank_diff > 0 ? rank_diff : 0; |
| int offset_y = rank_diff < 0 ? -rank_diff : 0; |
| for (int i = 0, s = std::min(shape_x.size(), shape_y.size()); i < s; i++) { |
| int index_x = i + offset_x; |
| int index_y = i + offset_y; |
| if ((shape_x[index_x] == -1 && shape_y[index_y] != 1) || |
| (shape_y[index_y] == -1 && shape_x[index_x] != 1)) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| static DenseIntElementsAttr getBroadcastDimensionsAttr(Builder &b, Value *x, |
| Value *y) { |
| TensorType xType = x->getType().dyn_cast<RankedTensorType>(); |
| TensorType yType = y->getType().dyn_cast<RankedTensorType>(); |
| if (xType == yType || !xType || !yType) return {}; |
| |
| // If the shapes have the same rank, then there is nothing to do. |
| auto xRank = xType.getRank(), yRank = yType.getRank(); |
| if (xRank == yRank) return {}; |
| |
| // Otherwise if the ranks of the inputs don't match, TensorFlow automatically |
| // reshapes the smaller by padding with dimensions of size 1 as a prefix. In |
| // other words to pad a 5-vector to a 3-dimensional tensor it is reshaped to |
| // have shape [1,1,5]. XLA's automatic broadcast code is able to broadcast |
| // from lower to higher rank, but doesn't assume you want to pad as a prefix |
| // of the dimensions, and instead needs to be told which dimensions of the |
| // higher rank tensor to match to the lower rank tensor. |
| auto maxRank = std::max(xRank, yRank); |
| auto minRank = std::min(xRank, yRank); |
| |
| // Match the lower rank tensor along the larger-numbered dimensions of the |
| // higher rank tensor. |
| SmallVector<int64_t, 4> broadcastDimensions(minRank); |
| std::iota(broadcastDimensions.begin(), broadcastDimensions.end(), |
| maxRank - minRank); |
| |
| RankedTensorType type = b.getTensorType({minRank}, b.getIntegerType(64)); |
| return DenseIntElementsAttr::get<int64_t>(type, broadcastDimensions) |
| .cast<DenseIntElementsAttr>(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Softmax op utilities. |
| //===----------------------------------------------------------------------===// |
| |
| // Returns a 1-d i64 elements attribute populated with numbers from start to |
| // end, excluding. |
| static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, |
| Builder *builder) { |
| int size = end - start; |
| |
| SmallVector<int64_t, 4> vals; |
| vals.resize(size); |
| std::iota(vals.begin(), vals.end(), start); |
| |
| TensorType ty = builder->getTensorType({size}, builder->getIntegerType(64)); |
| return DenseIntElementsAttr::get<int64_t>(ty, vals) |
| .cast<DenseIntElementsAttr>(); |
| } |
| |
| // Returns the type to use for accumulating the given type. |
| static Type GetAccumulationType(Type ty) { |
| // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from |
| // repeated floating point additions. |
| return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty; |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // ArgMax/ArgMin op utilities. |
| //===----------------------------------------------------------------------===// |
| |
| static void BuildArgMinMaxReductionBody(Type input_element_type, |
| Type index_element_type, |
| StringRef direction, Region *body, |
| OpBuilder *builder) { |
| OpBuilder::InsertionGuard insertion_point_gurad(*builder); |
| |
| Type input_type = builder->getTensorType(/*shape=*/{}, input_element_type); |
| Type index_type = builder->getTensorType(/*shape=*/{}, index_element_type); |
| Block *block = builder->createBlock(body); |
| block->addArguments({input_type, index_type, input_type, index_type}); |
| |
| Location loc = body->getLoc(); |
| Type compare_type = |
| builder->getTensorType(/*shape=*/{}, builder->getIntegerType(1)); |
| StringAttr compare_direction = |
| StringAttr::get(direction, builder->getContext()); |
| Value *compare = builder->create<xla_hlo::CompareOp>( |
| loc, compare_type, block->getArgument(0), block->getArgument(2), |
| /*broadcast_dimensions=*/nullptr, compare_direction); |
| |
| Value *selected_input = builder->create<xla_hlo::SelectOp>( |
| loc, input_type, compare, block->getArgument(0), block->getArgument(2)); |
| Value *selected_index = builder->create<xla_hlo::SelectOp>( |
| loc, index_type, compare, block->getArgument(1), block->getArgument(3)); |
| |
| Value *return_values[] = {selected_input, selected_index}; |
| builder->create<xla_hlo::ReturnOp>(loc, return_values); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Op converters. |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| namespace xla { |
| namespace { |
| |
| // Converts MaxPool op to HLO ReduceWindow op by setting appropriate window |
| // dimensions with max as the reduction function. |
| // |
| // Sample result for VALID padding mode: |
| // |
| // %init = constant dense<...> : tensor<i32> |
| // %max_pool = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.max"] |
| // {window_dimensions = ..., window_strides = ... } |
| // |
| class ConvertMaxPoolOp : public OpRewritePattern<TF::MaxPoolOp> { |
| public: |
| explicit ConvertMaxPoolOp(MLIRContext *context) |
| : OpRewritePattern<TF::MaxPoolOp>(context, 1) {} |
| |
| PatternMatchResult matchAndRewrite(TF::MaxPoolOp op, |
| PatternRewriter &rewriter) const override { |
| // TODO(hinsu): Support 'SAME' padding mode. |
| if (op.padding() != "VALID") return matchFailure(); |
| |
| Type element_type = |
| op.input()->getType().cast<TensorType>().getElementType(); |
| if (!element_type.isIntOrFloat()) return matchFailure(); |
| Location loc = op.getLoc(); |
| xla_hlo::ConstOp init = GetMinValueForType(element_type, loc, &rewriter); |
| |
| auto get_elements_attr = [&](ArrayAttr attr) { |
| RankedTensorType ty = rewriter.getTensorType( |
| static_cast<int64_t>(attr.size()), rewriter.getIntegerType(64)); |
| return DenseElementsAttr::get(ty, attr.getValue()) |
| .cast<DenseIntElementsAttr>(); |
| }; |
| |
| auto reduce = rewriter.create<xla_hlo::ReduceWindowOp>( |
| loc, op.getType(), op.input(), init.getResult(), |
| get_elements_attr(op.ksize()), get_elements_attr(op.strides()), |
| /*base_dilations=*/DenseIntElementsAttr(), |
| /*window_dilations=*/DenseIntElementsAttr(), |
| /*paddings=*/DenseIntElementsAttr()); |
| BuildReduceBody<xla_hlo::MaxOp>(element_type, &reduce.body(), &rewriter); |
| |
| rewriter.replaceOp(op, reduce.getResult(0)); |
| return matchSuccess(); |
| } |
| }; |
| |
| // Converts Sigmoid op to HLO ops computing sigmoid with the following formula: |
| // |
| // sigmoid = add(mul(tanh(mul(logits, 0.5)), 0.5), 0.5) |
| // |
| // Sample result with 2-d f16 inputs with B batches of with N elements each. |
| // |
| // // Create an array of 0.5 the shape of the input array. |
| // %half = xla_hlo.constant dense<5.000000e-01> : tensor<f32> |
| // %half_array = "xla_hlo.broadcast"(half) |
| // {broadcast_sizes = dense<2> : tensor<1xi64>} |
| // : (tensor<f32>) -> tensor<2xf32> |
| // |
| // // Compute Tanh of half the logits of the values. |
| // %halved_logits = xla_hlo.mul %logits, %half_array : tensor<2xf32> |
| // %tanh = "xla_hlo.tanh"(%halved_logits) : (tensor<2xf32>) -> tensor<2xf32> |
| // |
| // // Have the result of Tanh and add 0.5. |
| // %halved_tanh = xla_hlo.mul %tanh, %half : tensor<2xf32> |
| // %sigmoid = xla_hlo.add %halved_tanh, %half : tensor<2xf32> |
| // |
| class ConvertSigmoidOp : public OpRewritePattern<TF::SigmoidOp> { |
| public: |
| explicit ConvertSigmoidOp(MLIRContext *context) |
| : OpRewritePattern<TF::SigmoidOp>(context, 1) {} |
| |
| PatternMatchResult matchAndRewrite(TF::SigmoidOp op, |
| PatternRewriter &rewriter) const override { |
| auto operand = op.getOperand(); |
| |
| auto scalar_one = rewriter.create<xla_hlo::ConstOp>( |
| op.getLoc(), |
| rewriter.getFloatAttr(getElementTypeOrSelf(operand->getType()), 0.5)); |
| |
| auto shaped_type = operand->getType().cast<ShapedType>(); |
| auto constant_ones = rewriter.create<xla_hlo::BroadcastOp>( |
| op.getLoc(), shaped_type, scalar_one, |
| rewriter |
| .getDenseIntElementsAttr( |
| rewriter.getTensorType({shaped_type.getRank()}, |
| rewriter.getIntegerType(64)), |
| shaped_type.getShape()) |
| .cast<DenseIntElementsAttr>()); |
| |
| auto scaled_input = rewriter.create<xla_hlo::MulOp>( |
| op.getLoc(), operand->getType(), operand, constant_ones, |
| DenseIntElementsAttr()); |
| auto tanh_op = rewriter.create<xla_hlo::TanhOp>( |
| op.getLoc(), operand->getType(), scaled_input); |
| auto mul_op = rewriter.create<xla_hlo::MulOp>( |
| op.getLoc(), operand->getType(), tanh_op, constant_ones, |
| /*DenseIntElementsAttr=*/DenseIntElementsAttr()); |
| auto add_op = rewriter.create<xla_hlo::AddOp>( |
| op.getLoc(), operand->getType(), mul_op, constant_ones, |
| /*DenseIntElementsAttr=*/DenseIntElementsAttr()); |
| |
| rewriter.replaceOp(op, add_op.getResult()); |
| return matchSuccess(); |
| } |
| }; |
| |
| // Converts Softmax and LogSoftmax to HLO ops, computing softmax with the |
| // following formulas: |
| // |
| // softmax = div(exp(logits), sum(exp(logits))) |
| |
| // log_softmax = sub(logits, log(sum(exp(logits)))) |
| // |
| // Sample result with 2-d f16 inputs with B batches of with N elements each. |
| // |
| // // Subtract each element by their batches' max to improve numerical |
| // // stability. |
| // %neg_infinity = constant dense<0xFF800000> : tensor<f16> |
| // %max = "xla_hlo.reduce"(%input, %neg_infinity) ["xla_hlo.max"] |
| // {dimensions = 1} |
| // : (tensor<BxNxf16>, tensor<1xf16>) -> tensor<Bxf16> |
| // %sub = "xla_hlo.sub"(%inp, %max) {broadcast_dimensions = 0} |
| // : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16> |
| // |
| // %exp = "xla_hlo.exp"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16> |
| // |
| // // Cast to f32 to avoid precision loss in summation. |
| // %exp_f32 = "xla_hlo.convert"(%exp) : (tensor<BxNxbf16>) -> tensor<BxNxf32> |
| // %zero = constant dense<0.000000e+00> : tensor<f32> |
| // %sum = "xla_hlo.reduce"(%exp, %zero) ["xla_hlo.add"] {dimensions = 1} |
| // : (tensor<BxNxf32>, tensor<1xf32>) -> tensor<Bxf32> |
| // |
| // %sum_f16 = "xla_hlo.convert"(%sum) : (tensor<BxNxbf32>) -> tensor<BxNxf16> |
| // |
| // // Softmax computation: |
| // %softmax = "xla_hlo.div"(%exp, %sum_f16) {broadcast_dimensions = 0} |
| // : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16> |
| // |
| // TODO(hinsu): Use tf.Max and tf.Sum instead of lowering directly to xla. |
| template <typename OpTy, bool use_log = true> |
| class ConvertSoftmaxOp : public OpRewritePattern<OpTy> { |
| public: |
| explicit ConvertSoftmaxOp(MLIRContext *context) |
| : OpRewritePattern<OpTy>(context, 1) {} |
| |
| PatternMatchResult matchAndRewrite(OpTy op, |
| PatternRewriter &rewriter) const override { |
| Value *logits = op.logits(); |
| |
| // Softmax converter requires ranked type because the XLA reduce ops used |
| // while lowering requires dimensions attribute to reduce along. |
| RankedTensorType type = logits->getType().dyn_cast<RankedTensorType>(); |
| if (!type) return Pattern::matchFailure(); |
| |
| auto loc = op.getLoc(); |
| |
| int rank = type.getRank(); |
| |
| // Note that the TensorFlow Softmax op verifies that the input rank is |
| // greater than or equal to one so both of the following sequences are |
| // valid. |
| auto batch_dims = GetI64ElementsAttrForSeq(0, rank - 1, &rewriter); |
| auto reduce_dim = GetI64ElementsAttrForSeq(rank - 1, rank, &rewriter); |
| |
| // Exponential of input values and then their sum can be very large here. |
| // Division with large denominator is numerically unstable. To improve |
| // numerical stability, subtract each batch with their max element so that |
| // the maximum input value is zero. It can be shown that softmax computed |
| // after adding or subtracting all inputs in a batch using a common value |
| // gives mathematically equivalent result. |
| Type element_type = type.getElementType(); |
| ArrayRef<int64_t> reduce_shape = type.getShape().drop_back(); |
| RankedTensorType reduce_out_type = |
| rewriter.getTensorType(reduce_shape, element_type); |
| auto init = GetMinValueForType(element_type, loc, &rewriter); |
| auto max_logits = rewriter.create<xla_hlo::ReduceOp>( |
| loc, reduce_out_type, logits, init.getResult(), reduce_dim); |
| BuildReduceBody<xla_hlo::MaxOp>(element_type, &max_logits.body(), |
| &rewriter); |
| auto shifted_logits = rewriter.create<xla_hlo::SubOp>( |
| loc, type, logits, max_logits.getResult(0), batch_dims); |
| |
| // Exponentiate the inputs. |
| Value *exp = rewriter.create<xla_hlo::ExpOp>(loc, type, shifted_logits); |
| |
| // Cast the exponentials to the appropriate accumulation type to avoid |
| // precision loss during summation. |
| Type sum_element_type = GetAccumulationType(element_type); |
| Type sum_type = rewriter.getTensorType(type.getShape(), sum_element_type); |
| auto casted_exp = rewriter.create<xla_hlo::ConvertOp>(loc, sum_type, exp); |
| |
| // Compute summation of the exponentials. |
| init = rewriter.create<xla_hlo::ConstOp>( |
| loc, DenseElementsAttr::get(rewriter.getTensorType({}, element_type), |
| rewriter.getZeroAttr(element_type))); |
| Type sum_out_type = rewriter.getTensorType(reduce_shape, sum_element_type); |
| auto exp_sum = rewriter.create<xla_hlo::ReduceOp>( |
| loc, sum_out_type, casted_exp.getResult(), init.getResult(), |
| reduce_dim); |
| BuildReduceBody<xla_hlo::AddOp>(element_type, &exp_sum.body(), &rewriter); |
| Value *sum = exp_sum.getResult(0); |
| |
| // Convert the summation result back to the original element type and divide |
| // exponentials by the summations. |
| sum = rewriter.create<xla_hlo::ConvertOp>(loc, reduce_out_type, sum); |
| |
| if (use_log) { |
| Value *log = rewriter.create<xla_hlo::LogOp>(loc, reduce_out_type, sum); |
| rewriter.replaceOpWithNewOp<xla_hlo::SubOp>( |
| op, op.getType(), shifted_logits, log, batch_dims); |
| } else { |
| rewriter.replaceOpWithNewOp<xla_hlo::DivOp>(op, op.getType(), exp, sum, |
| batch_dims); |
| } |
| |
| return Pattern::matchSuccess(); |
| } |
| }; |
| |
| // Converts StridedSlice op to HLO Slice op along with Reverse op to handle |
| // negative strides and Reshape op to update the output shape. Indices and |
| // strides operands are converted to attributes with non-negative indexing. |
| // |
| // For example with an op like following, |
| // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} |
| // : tensor<AxBxf32> -> tensor<Pxf32> |
| // |
| // Output would be: |
| // %reversed = "xla_hlo.Reverse" (%input) {dimensions = ...} |
| // %sliced = "xla_hlo.Slice" (%input) |
| // {start_indices = ..., limit_indices = ..., strides = ...} |
| // %output = "xla_hlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor<Pxf32> |
| // |
| class ConvertStridedSliceOp : public OpRewritePattern<TF::StridedSliceOp> { |
| public: |
| explicit ConvertStridedSliceOp(MLIRContext *context) |
| : OpRewritePattern<TF::StridedSliceOp>(context, 1) {} |
| |
| PatternMatchResult matchAndRewrite(TF::StridedSliceOp op, |
| PatternRewriter &rewriter) const override { |
| // Input shape needs to be static to convert negative indices in TensorFlow |
| // to absolute indices required by HLO. |
| // |
| // TODO(hinsu): Relax this constraint for ops without negative indices and |
| // strides. |
| auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>(); |
| if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); |
| ArrayRef<int64_t> input_shape = input_ty.getShape(); |
| |
| // Output shape needs to be static to apply 'new_axis_mask' or |
| // 'shrink_axis_mask' by reshaping tensor after slice. |
| // |
| // TODO(hinsu): Relax this constraint for ops without the above masks. |
| auto result_ty = op.getType().dyn_cast<RankedTensorType>(); |
| if (!result_ty || !result_ty.hasStaticShape()) return matchFailure(); |
| |
| // TODO(hinsu): Support non-zero mask values. Currently only |
| // 'shrink_axis_mask' is supported. |
| for (StringRef mask : |
| {"begin_mask", "end_mask", "ellipsis_mask", "new_axis_mask"}) { |
| auto attr = op.getAttrOfType<IntegerAttr>(mask); |
| if (attr && attr.getValue() != 0) return matchFailure(); |
| } |
| |
| // TODO(hinsu): Support lowering for ops with dynamic begin and end values |
| // when it is possible to derive indices based on mask attributes. |
| DenseIntElementsAttr begin_indices, end_indices, strides; |
| if (!matchPattern(op.begin(), m_Constant(&begin_indices)) || |
| !matchPattern(op.end(), m_Constant(&end_indices)) || |
| !matchPattern(op.strides(), m_Constant(&strides))) |
| return matchFailure(); |
| |
| SmallVector<int64_t, 4> hlo_begin_indices, hlo_end_indices, hlo_strides, |
| dims_to_reverse; |
| int64_t input_rank = input_ty.getRank(); |
| for (auto *vec : {&hlo_begin_indices, &hlo_end_indices, &hlo_strides}) { |
| vec->reserve(input_rank); |
| } |
| |
| int64_t indices_elements = begin_indices.getNumElements(); |
| if (input_rank < indices_elements) return matchFailure(); |
| |
| // Convert from TensorFlow negative or out of range indices and strides |
| // values to legal HLO Slice attributes. |
| for (int i = 0, e = indices_elements; i != e; i++) { |
| int64_t begin = begin_indices.getValue<IntegerAttr>(i).getInt(); |
| int64_t end = end_indices.getValue<IntegerAttr>(i).getInt(); |
| int64_t stride = strides.getValue<IntegerAttr>(i).getInt(); |
| |
| if (begin < 0) begin = input_shape[i] + begin; |
| if (end < 0) end = input_shape[i] + end; |
| |
| if (stride < 0) { |
| // Negative stride means that the output values are computed starting |
| // from end until begin. Mark the dimension for reversal before slice |
| // and compute indices for the reversed input. |
| dims_to_reverse.push_back(i); |
| begin = (input_shape[i] - 1) - begin; |
| end = (input_shape[i] - 1) - end; |
| stride = -stride; |
| } |
| |
| // Unlike TensorFlow, HLO requires begin and end values to be within |
| // range. |
| begin = std::max(int64_t(0), begin); |
| end = std::max(begin, end); |
| end = std::min(end, input_shape[i]); |
| |
| hlo_begin_indices.push_back(begin); |
| hlo_end_indices.push_back(end); |
| hlo_strides.push_back(stride); |
| } |
| |
| Location loc = op.getLoc(); |
| auto reversed = rewriter.create<xla_hlo::ReverseOp>( |
| loc, input_ty, op.input(), |
| GetI64ElementsAttr(dims_to_reverse, &rewriter)); |
| auto sliced = rewriter.create<xla_hlo::SliceOp>( |
| loc, reversed.getResult(), |
| GetI64ElementsAttr(hlo_begin_indices, &rewriter), |
| GetI64ElementsAttr(hlo_end_indices, &rewriter), |
| GetI64ElementsAttr(hlo_strides, &rewriter)); |
| |
| // Reshape slice result so that the shape is updated depending on |
| // 'new_axis_mask' or 'shrink_axis_mask' attributes. |
| rewriter.replaceOpWithNewOp<xla_hlo::ReshapeOp>(op, op.getType(), sliced); |
| return matchSuccess(); |
| } |
| }; |
| |
| /// Converts a generic OpTy tensorflow op to a xla_hlo.reduce op over |
| /// ReductionOp. |
| /// `is_accumulation` controls whether it uses higher precision for the actual |
| /// reduction. This is set to false for ops like max where there is no precision |
| /// concerns. |
| template <typename Derived, typename OpTy, typename ReductionOp, |
| bool is_accumulation = true> |
| class GenericConvertReductionOp : public OpRewritePattern<OpTy> { |
| using OpRewritePattern<OpTy>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(OpTy op, |
| PatternRewriter &rewriter) const override { |
| // TODO(b/141785544): Update this to not require static shapes. |
| // Input shape needs to be static to convert negative indices in TensorFlow |
| // to absolute indices required by HLO. |
| auto input_ty = op.input()->getType().template dyn_cast<RankedTensorType>(); |
| if (!input_ty) return this->matchFailure(); |
| ArrayRef<int64_t> input_shape = input_ty.getShape(); |
| |
| DenseIntElementsAttr dimensions; |
| if (!matchPattern(op.reduction_indices(), m_Constant(&dimensions)) || |
| dimensions.getType().getRank() != 1) |
| return this->matchFailure(); |
| |
| // Build the final shape from input_shape and dimensions using a bitmap |
| // to mark the reduced dimensions. |
| SmallVector<bool, 4> reduced_dimensions_bitmap(input_shape.size(), false); |
| SmallVector<int64_t, 4> xla_dimensions; |
| for (APInt index_raw : dimensions.getValues<APInt>()) { |
| int64_t index = index_raw.getSExtValue(); |
| int64_t rank = input_shape.size(); |
| if ((index < -rank || index >= rank)) return this->matchFailure(); |
| index = (index + rank) % rank; |
| reduced_dimensions_bitmap[index] = true; |
| xla_dimensions.push_back(index); |
| } |
| SmallVector<int64_t, 4> reduced_shape; |
| reduced_shape.reserve(input_shape.size()); |
| for (size_t i = 0; i < input_shape.size(); ++i) { |
| if (!reduced_dimensions_bitmap[i]) { |
| // If we are not reducing along dimension i. |
| int64_t dim = input_shape[i]; |
| reduced_shape.push_back(dim); |
| } |
| } |
| |
| Location loc = op.getLoc(); |
| Type element_type = input_ty.getElementType(); |
| // Convert to an accumulation type to not lose precision when doing |
| // repeated arithmetic operations. |
| Type reduce_element_type = |
| is_accumulation ? GetAccumulationType(element_type) : element_type; |
| auto casted_input = rewriter.create<xla_hlo::ConvertOp>( |
| loc, rewriter.getTensorType(input_shape, reduce_element_type), |
| op.input()); |
| |
| // Each reduction op can have a different initial value. |
| Value *init = Derived::GetInitialValue(reduce_element_type, loc, rewriter); |
| |
| Type reduced_out_type = |
| rewriter.getTensorType(reduced_shape, reduce_element_type); |
| // TODO(hinsu): Infer reduced_out_type. |
| auto reduction = rewriter.create<xla_hlo::ReduceOp>( |
| loc, reduced_out_type, casted_input.getResult(), init, |
| GetI64ElementsAttr(xla_dimensions, &rewriter)); |
| BuildReduceBody<ReductionOp>(reduce_element_type, &reduction.body(), |
| &rewriter); |
| Value *result = reduction.getResult(0); |
| |
| // The mean op needs to divide by the product of the reduced dimensions. |
| if (std::is_same<OpTy, TF::MeanOp>::value) { |
| int64_t divisor_count = 1; |
| for (size_t i = 0; i < input_shape.size(); ++i) { |
| if (reduced_dimensions_bitmap[i]) { |
| if (TensorType::isDynamic(input_shape[i])) { |
| return this->matchFailure(); |
| } |
| divisor_count *= input_shape[i]; |
| } |
| } |
| auto divisor = |
| GetScalarForType(reduce_element_type, loc, divisor_count, &rewriter); |
| result = rewriter.create<xla_hlo::DivOp>( |
| loc, reduced_out_type, result, divisor.getResult(), |
| /* broadcast_dimensions= */ DenseIntElementsAttr()); |
| } |
| |
| Type reduced_final_type = |
| rewriter.getTensorType(reduced_shape, element_type); |
| result = |
| rewriter.create<xla_hlo::ConvertOp>(loc, reduced_final_type, result); |
| |
| // Need to reshape back after the reduction if we're keeping the reduced |
| // dimensions. |
| if (op.keep_dims()) { |
| result = rewriter.create<xla_hlo::ReshapeOp>(loc, op.getType(), result); |
| } |
| rewriter.replaceOp(op, {result}, {op.reduction_indices()}); |
| |
| return this->matchSuccess(); |
| } |
| }; |
| |
| // Converts Mean op to HLO Reduce op. |
| // |
| // %init = constant dense<...> : tensor<T> |
| // %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"] |
| // {dimensions = ...} |
| // %divisor = constant dense<...> : tensor<T> |
| // %mean = "xla_hlo.div"(%sum, %divisor) |
| class ConvertMeanOp |
| : public GenericConvertReductionOp<ConvertMeanOp, TF::MeanOp, |
| xla_hlo::AddOp> { |
| public: |
| using GenericConvertReductionOp::GenericConvertReductionOp; |
| |
| static Value *GetInitialValue(Type reduce_element_type, Location loc, |
| PatternRewriter &rewriter) { |
| return GetScalarForType(reduce_element_type, loc, 0, &rewriter); |
| } |
| }; |
| |
| // Converts Sum op to HLO Reduce op. |
| // |
| // %init = constant dense<...> : tensor<T> |
| // %sum = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.add"] |
| // {dimensions = ...} |
| class ConvertSumOp : public GenericConvertReductionOp<ConvertSumOp, TF::SumOp, |
| xla_hlo::AddOp> { |
| public: |
| using GenericConvertReductionOp::GenericConvertReductionOp; |
| |
| static Value *GetInitialValue(Type reduce_element_type, Location loc, |
| PatternRewriter &rewriter) { |
| return GetScalarForType(reduce_element_type, loc, 0, &rewriter); |
| } |
| }; |
| |
| // Converts Max op to HLO Reduce op. |
| // |
| // %init = constant dense<...> : tensor<T> |
| // %max = "xla_hlo.reduce"(%inp, %init) ["xla_hlo.max"] |
| // {dimensions = ...} |
| class ConvertMaxOp |
| : public GenericConvertReductionOp<ConvertMaxOp, TF::MaxOp, xla_hlo::MaxOp, |
| /* is_accumulation= */ false> { |
| public: |
| using GenericConvertReductionOp::GenericConvertReductionOp; |
| |
| static Value *GetInitialValue(Type reduce_element_type, Location loc, |
| PatternRewriter &rewriter) { |
| return GetMinValueForType(reduce_element_type, loc, &rewriter); |
| } |
| }; |
| |
| // Converts tensorflow ArgMin or ArgMax op to xla_hlo operations that perform |
| // a reduction on the original input and the corresponding index. The reduction |
| // sub-computation selects the max (or min) value and the index for the value. |
| // Derived: is the resulting derived class of this class. |
| // OpTy: is TF::ArgMaxOp or TF::ArgMinOp. |
| template <typename Derived, typename OpTy> |
| class ConvertArgMinMaxOp : public OpRewritePattern<OpTy> { |
| using OpRewritePattern<OpTy>::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(OpTy op, |
| PatternRewriter &rewriter) const override { |
| RankedTensorType input_type = |
| op.input()->getType().template dyn_cast<RankedTensorType>(); |
| if (!input_type) { |
| return this->matchFailure(); |
| } |
| |
| Type input_element_type = input_type.getElementType(); |
| // TODO(bixia): Clarify whether tf.ArgMax supports complex data types. If |
| // tf.ArgMax doesn't support complex data types, this check can be removed. |
| if (!input_element_type.isIntOrFloat()) return this->matchFailure(); |
| |
| Location loc = op.getLoc(); |
| Value *init_value = |
| Derived::GetInitialValue(input_element_type, loc, rewriter); |
| |
| RankedTensorType output_type = |
| op.output()->getType().template dyn_cast<RankedTensorType>(); |
| if (!output_type) { |
| return this->matchFailure(); |
| } |
| |
| Type index_element_type = output_type.getElementType(); |
| Value *index_init_value = |
| GetScalarForType(index_element_type, loc, 0, &rewriter); |
| |
| RankedTensorType index_type = |
| rewriter.getTensorType(input_type.getShape(), index_element_type); |
| |
| llvm::Optional<int64_t> optional_axis = |
| GetIntegerHLOAxisFromTFAxis(op.dimension(), input_type.getRank()); |
| if (!optional_axis.hasValue()) { |
| return this->matchFailure(); |
| } |
| int64_t axis = optional_axis.getValue(); |
| |
| IntegerAttr iota_dimension = |
| IntegerAttr::get(rewriter.getIntegerType(64), axis); |
| Value *index_values = |
| rewriter.create<xla_hlo::IotaOp>(loc, index_type, iota_dimension); |
| |
| std::vector<int64_t> dimensions = input_type.getShape(); |
| dimensions.erase(dimensions.begin() + axis); |
| ArrayRef<int64_t> reduction_result_shape(dimensions); |
| |
| Type input_reduction_result_type = rewriter.getTensorType( |
| reduction_result_shape, input_type.getElementType()); |
| Type index_reduction_result_type = rewriter.getTensorType( |
| reduction_result_shape, index_type.getElementType()); |
| |
| Type result_types[] = {input_reduction_result_type, |
| index_reduction_result_type}; |
| Value *operands[] = {op.input(), index_values}; |
| Value *init_values[] = {init_value, index_init_value}; |
| DenseIntElementsAttr reduction_dimensions = |
| GetI64ElementsAttr({axis}, &rewriter); |
| |
| auto reduction = rewriter.create<xla_hlo::ReduceOp>( |
| loc, llvm::ArrayRef<Type>(result_types), |
| llvm::ArrayRef<Value *>(operands), llvm::ArrayRef<Value *>(init_values), |
| reduction_dimensions); |
| StringRef direction = Derived::GetDirection(); |
| BuildArgMinMaxReductionBody(input_element_type, index_element_type, |
| direction, &reduction.body(), &rewriter); |
| |
| rewriter.replaceOp(op, {reduction.getResult(1)}); |
| return this->matchSuccess(); |
| } |
| }; |
| |
| // Converts tensorflow ArgMax op to xla_hlo operations. The actual |
| // implementation is in class ConvertArgMinMaxOp: |
| // |
| // %init_index = constant dense<...> : tensor<T> |
| // %init = constant dense<...> : tensor<T> |
| // %reduce = "xla_hlo.reduce"(%selected_input, %select_index, %init, |
| // %init_index) ["xla_hlo.arg_max"] |
| class ConvertArgMaxOp |
| : public ConvertArgMinMaxOp<ConvertArgMaxOp, TF::ArgMaxOp> { |
| public: |
| using ConvertArgMinMaxOp::ConvertArgMinMaxOp; |
| |
| static Value *GetInitialValue(Type reduce_element_type, Location loc, |
| PatternRewriter &rewriter) { |
| return GetMinValueForType(reduce_element_type, loc, &rewriter); |
| } |
| |
| static StringRef GetDirection() { return "GT"; } |
| }; |
| |
| // Converts Tile op to HLO BroadcastInDim and Reshape ops. |
| // For shape [S1, S2] and multiples [M1, M2], |
| // MS1 = M1 * S1; MS2 = M2 * S2 |
| // |
| // %broadcast = xla_hlo.broadcast_in_dim(%input) { |
| // broadcast_dimensions = [0, 2] |
| // } |
| // %result = "xla_hlo.reshape"(%broadcast) : (tensor<S1xM1xS2xM2xf32>) |
| // -> tensor<MS1xMS2xf32> |
| class ConvertTileOp : public OpRewritePattern<TF::TileOp> { |
| public: |
| using OpRewritePattern::OpRewritePattern; |
| |
| PatternMatchResult matchAndRewrite(TF::TileOp op, |
| PatternRewriter &rewriter) const override { |
| auto input_ty = op.input()->getType().dyn_cast<RankedTensorType>(); |
| if (!input_ty || !input_ty.hasStaticShape()) return matchFailure(); |
| ArrayRef<int64_t> input_shape = input_ty.getShape(); |
| Type element_type = input_ty.getElementType(); |
| |
| DenseIntElementsAttr multiples; |
| if (!matchPattern(op.multiples(), m_Constant(&multiples)) || |
| multiples.getType().getRank() != 1) |
| return matchFailure(); |
| |
| if (multiples.getNumElements() != input_shape.size()) return matchFailure(); |
| |
| SmallVector<int64_t, 8> broadcasted_shape; |
| SmallVector<int64_t, 4> broadcast_dimensions; |
| broadcasted_shape.reserve(input_shape.size() * 2); |
| broadcast_dimensions.reserve(input_shape.size()); |
| for (auto multiple_and_input : |
| llvm::zip(multiples.getValues<APInt>(), input_shape)) { |
| int64_t multiple = std::get<0>(multiple_and_input).getSExtValue(); |
| int64_t input_size = std::get<1>(multiple_and_input); |
| |
| if (multiple < 0) return matchFailure(); |
| |
| // Line input up with the next dimension in broadcasted_shape |
| // when broadcasting. |
| broadcast_dimensions.push_back(broadcasted_shape.size()); |
| int64_t output_size = input_size * multiple; |
| if (input_size == 1 || multiple == 1) { |
| // Special case for when normal broadcasting will just work. |
| broadcasted_shape.push_back(output_size); |
| } else { |
| // Tiling will happen for this dimension during the ReshapeOp below. |
| broadcasted_shape.push_back(input_size); |
| broadcasted_shape.push_back(multiple); |
| } |
| } |
| Location loc = op.getLoc(); |
| Type broadcasted_type = |
| rewriter.getTensorType(broadcasted_shape, element_type); |
| Type output_type = op.getType(); |
| |
| Value *result = rewriter.create<xla_hlo::BroadcastInDimOp>( |
| loc, broadcasted_type, op.input(), |
| GetI64ElementsAttr(broadcast_dimensions, &rewriter)); |
| |
| if (output_type != broadcasted_type) { |
| result = rewriter.create<xla_hlo::ReshapeOp>(loc, output_type, result); |
| } |
| |
| rewriter.replaceOp(op, {result}, {op.multiples()}); |
| |
| return matchSuccess(); |
| } |
| }; |
| |
| #include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc" |
| } // end anonymous namespace |
| } // end namespace xla |
| } // end namespace mlir |
| |
| LogicalResult mlir::xla_hlo::legalizeTF(Operation *op) { |
| MLIRContext *context = op->getContext(); |
| |
| // Add lowering patterns to the list. |
| OwningRewritePatternList patterns; |
| xla::populateWithGenerated(context, &patterns); |
| |
| // Add patterns that lower some of the high level TensorFlow ops to lower |
| // level TensorFlow ops. So, we don't have to target all the TensorFlow ops |
| // here for lowering to HLO. |
| mlir::TF::PopulateLoweringTFPatterns(context, &patterns); |
| patterns.insert<mlir::xla::ConvertArgMaxOp, mlir::xla::ConvertMaxPoolOp, |
| mlir::xla::ConvertSigmoidOp, |
| mlir::xla::ConvertSoftmaxOp<TF::LogSoftmaxOp, true>, |
| mlir::xla::ConvertSoftmaxOp<TF::SoftmaxOp, false>, |
| mlir::xla::ConvertStridedSliceOp, mlir::xla::ConvertMeanOp, |
| mlir::xla::ConvertSumOp, mlir::xla::ConvertMaxOp, |
| mlir::xla::ConvertTileOp>(op->getContext()); |
| |
| ConversionTarget target(*context); |
| target.addLegalDialect<XlaHloDialect>(); |
| |
| return applyPartialConversion(op, target, patterns); |
| } |
| |
| /// Performs the lowering to XLA dialect. |
| void LegalizeTF::runOnFunction() { |
| if (failed(mlir::xla_hlo::legalizeTF(getFunction()))) signalPassFailure(); |
| } |
| |
| static PassRegistration<LegalizeTF> pass( |
| "xla-legalize-tf", "Legalize from TensorFlow to the XLA dialect"); |