Add XLA lowering pattern for TensorFlow Softmax op
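
The pattern lowers tf.Softmax to the usual numerically stable form: reduce-max over the last
dimension, subtract, exponentiate, reduce-add (accumulating in f32 for f16/bf16 inputs), then
divide. For reference, here is a minimal standalone C++ sketch of the same computation on a dense
[batch, n] buffer; it is illustration only, not part of this change, and the stable_softmax helper
name is hypothetical.

// Numerically stable softmax over the last dimension of a [batch, n] matrix.
// Mirrors the emitted HLO sequence: reduce-max, subtract, exp, reduce-add, divide.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> stable_softmax(const std::vector<float> &logits, int batch, int n) {
  std::vector<float> result(logits.size());
  for (int b = 0; b < batch; ++b) {
    const float *row = logits.data() + b * n;
    float *out = result.data() + b * n;
    // Subtracting the row max is mathematically a no-op for softmax but keeps exp() from
    // overflowing.
    float max = *std::max_element(row, row + n);
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
      out[i] = std::exp(row[i] - max);
      sum += out[i];
    }
    for (int i = 0; i < n; ++i) out[i] /= sum;
  }
  return result;
}

This is the shift-by-max trick that the emitted reduce/sub/exp/reduce/div sequence implements; in
addition, the pattern accumulates the sum in f32 when the element type is f16 or bf16.
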
PiperOrigin-RevId: 267740332
diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
index 1b020c0..3b94f7b 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir
@@ -216,6 +216,64 @@
}
//===----------------------------------------------------------------------===//
+// Softmax op legalizations.
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @simple_softmax
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>)
+func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
+ // CHECK: %[[NEG_INF:.*]] = constant dense<0xFF800000> : tensor<f32>
+ // CHECK: %[[ZERO:.*]] = constant dense<0.000000e+00> : tensor<f32>
+
+ // Verify reduce op for max computation and its body.
+ // CHECK: %[[MAX:.*]] = "xla_hlo.reduce"(%[[ARG0]], %[[NEG_INF]])
+ // CHECK: xla_hlo.max
+ // CHECK: "xla_hlo.return"
+ // CHECK: {dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf32>, tensor<f32>) -> tensor<2xf32>
+
+ // CHECK: %[[SHIFTED_INP:.*]] = "xla_hlo.sub"(%[[ARG0]], %[[MAX]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
+ // CHECK: %[[EXP:.*]] = "xla_hlo.exp"(%[[SHIFTED_INP]])
+
+ // Verify reduce op for summation and its body.
+ // CHECK: %[[SUM:.*]] = "xla_hlo.reduce"(%[[EXP]], %[[ZERO]])
+ // CHECK: xla_hlo.add
+ // CHECK: "xla_hlo.return"
+ // CHECK: {dimensions = dense<1> : tensor<1xi64>}
+
+ // CHECK: %[[RESULT:.*]] = "xla_hlo.div"(%[[EXP]], %[[SUM]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
+ // CHECK: return %[[RESULT]]
+
+ %0 = "tf.Softmax"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32>
+ return %0: tensor<2x3xf32>
+}
+
+// CHECK-LABEL: bf16_softmax
+func @bf16_softmax(%arg0: tensor<2x3xbf16>) -> tensor<2x3xbf16> {
+ // Verify that conversions to f32 and back to bf16 are introduced.
+
+ // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<2x3xbf16>) -> tensor<2x3xf32>
+ // CHECK: "xla_hlo.convert"({{.*}}) : (tensor<2xf32>) -> tensor<2xbf16>
+
+ %0 = "tf.Softmax"(%arg0) : (tensor<2x3xbf16>) -> tensor<2x3xbf16>
+ return %0: tensor<2x3xbf16>
+}
+
+// CHECK-LABEL: rank4_softmax
+func @rank4_softmax(%arg0: tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16> {
+ // Verify that reduce op dimensions and broadcast dimensions are correct.
+
+ // CHECK: "xla_hlo.reduce"
+ // CHECK: dimensions = dense<3>
+
+ // CHECK: "xla_hlo.reduce"
+ // CHECK: dimensions = dense<3>
+
+ // CHECK: "xla_hlo.div"{{.*}} {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
+ %0 = "tf.Softmax"(%arg0) : (tensor<2x3x4x5xf16>) -> tensor<2x3x4x5xf16>
+ return %0: tensor<2x3x4x5xf16>
+}
+
+//===----------------------------------------------------------------------===//
// Unary op legalizations.
//===----------------------------------------------------------------------===//
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
index 1e2b403..b715b52 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc
@@ -19,8 +19,11 @@
#include "mlir/Dialect/StandardOps/Ops.h" // TF:local_config_mlir
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
+#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
+#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/IR/PatternMatch.h" // TF:local_config_mlir
+#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Pass/Pass.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
@@ -168,6 +171,34 @@
b.getTensorType({minRank}, b.getIntegerType(64)), broadcastDimensions);
}
+//===----------------------------------------------------------------------===//
+// Softmax op utilities.
+//===----------------------------------------------------------------------===//
+
+// Returns a 1-d i64 elements attribute populated with the numbers from start
+// (inclusive) to end (exclusive).
+static ElementsAttr GetI64AttrForSeq(int start, int end, Builder *builder) {
+ int size = end - start;
+
+ SmallVector<int64_t, 4> vals;
+ vals.resize(size);
+ std::iota(vals.begin(), vals.end(), start);
+
+ TensorType ty = builder->getTensorType({size}, builder->getIntegerType(64));
+ return DenseIntElementsAttr::get<int64_t>(ty, vals);
+}
+
+// Returns the type to use for accumulating the given type.
+static Type GetAccumulationType(Type ty) {
+ // Upcast 16-bit sum reductions to 32-bit to reduce the precision loss from
+ // repeated floating-point additions.
+ return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
+}
+
+//===----------------------------------------------------------------------===//
+// Op converters.
+//===----------------------------------------------------------------------===//
+
namespace mlir {
namespace xla {
namespace {
@@ -217,6 +248,101 @@
}
};
+// Converts Softmax op to HLO ops computing softmax with the following formula:
+//
+// softmax = div(exp(logits), sum(exp(logits)))
+//
+// Sample result for a 2-d f16 input with B batches of N elements each.
+//
+// // Subtract the per-batch max from each element to improve numerical
+// // stability.
+// %neg_infinity = constant dense<0xFF800000> : tensor<f16>
+// %max = "xla_hlo.reduce"(%input, %neg_infinity) ["xla_hlo.max"]
+// {dimensions = 1}
+// : (tensor<BxNxf16>, tensor<1xf16>) -> tensor<Bxf16>
+// %sub = "xla_hlo.sub"(%inp, %max) {broadcast_dimensions = 0}
+// : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
+//
+// %exp = "xla_hlo.exp"(%sub) : (tensor<BxNxf16>) -> tensor<BxNxf16>
+//
+// // Cast to f32 to avoid precision loss in summation.
+// %exp_f32 = "xla_hlo.convert"(%exp) : (tensor<BxNxf16>) -> tensor<BxNxf32>
+// %zero = constant dense<0.000000e+00> : tensor<f32>
+// %sum = "xla_hlo.reduce"(%exp_f32, %zero) ["xla_hlo.add"] {dimensions = 1}
+// : (tensor<BxNxf32>, tensor<1xf32>) -> tensor<Bxf32>
+//
+// %sum_f16 = "xla_hlo.convert"(%sum) : (tensor<Bxf32>) -> tensor<Bxf16>
+// %softmax = "xla_hlo.div"(%exp, %sum_f16) {broadcast_dimensions = 0}
+// : (tensor<BxNxf16>, tensor<Bxf16>) -> tensor<BxNxf16>
+//
+class ConvertSoftmaxOp : public OpRewritePattern<TF::SoftmaxOp> {
+ public:
+ explicit ConvertSoftmaxOp(MLIRContext *context)
+ : OpRewritePattern<TF::SoftmaxOp>(context, 1) {}
+
+ PatternMatchResult matchAndRewrite(TF::SoftmaxOp op,
+ PatternRewriter &rewriter) const override {
+ Value *logits = op.logits();
+
+ // The Softmax converter requires a ranked type because the XLA reduce ops
+ // used while lowering require a dimensions attribute to reduce along.
+ RankedTensorType type = logits->getType().dyn_cast<RankedTensorType>();
+ if (!type) return matchFailure();
+ int rank = type.getRank();
+
+ // Note that the TensorFlow Softmax op verifies that the input rank is
+ // greater than or equal to one, so both of the following dimension
+ // sequences are valid.
+ ElementsAttr batch_dims = GetI64AttrForSeq(0, rank - 1, &rewriter);
+ ElementsAttr reduce_dim = GetI64AttrForSeq(rank - 1, rank, &rewriter);
+ Location loc = op.getLoc();
+
+ // The exponentials of the input values, and therefore their sum, can be
+ // very large here, and dividing by a large denominator is numerically
+ // unstable. To improve stability, subtract each batch's max element from
+ // all elements in that batch so that the maximum input value is zero.
+ // Softmax is invariant under shifting all inputs in a batch by a common
+ // value, so the result is mathematically equivalent.
+ Type element_type = type.getElementType();
+ ArrayRef<int64_t> reduce_shape = type.getShape().drop_back();
+ RankedTensorType reduce_out_type =
+ rewriter.getTensorType(reduce_shape, element_type);
+ auto init = GetMinValueForType(element_type, loc, &rewriter);
+ auto max_logits = rewriter.create<xla_hlo::ReduceOp>(
+ loc, reduce_out_type, ArrayRef<Value *>{logits, init}, reduce_dim);
+ BuildReduceBody<xla_hlo::MaxOp>(element_type, &max_logits.body(),
+ &rewriter);
+ auto shifted_logits = rewriter.create<xla_hlo::SubOp>(
+ loc, type, logits, max_logits.getResult(0), batch_dims);
+
+ // Exponentiate the inputs.
+ Value *exp = rewriter.create<xla_hlo::ExpOp>(loc, type, shifted_logits);
+
+ // Cast the exponentials to the appropriate accumulation type to avoid
+ // precision loss during summation.
+ Type sum_element_type = GetAccumulationType(element_type);
+ Type sum_type = rewriter.getTensorType(type.getShape(), sum_element_type);
+ auto casted_exp = rewriter.create<xla_hlo::ConvertOp>(loc, sum_type, exp);
+
+ // Compute summation of the exponentials.
+ init = rewriter.create<ConstantOp>(
+ loc, DenseElementsAttr::get(rewriter.getTensorType({}, sum_element_type),
+ rewriter.getZeroAttr(sum_element_type)));
+ Type sum_out_type = rewriter.getTensorType(reduce_shape, sum_element_type);
+ auto exp_sum = rewriter.create<xla_hlo::ReduceOp>(
+ loc, sum_out_type, ArrayRef<Value *>{casted_exp, init}, reduce_dim);
+ BuildReduceBody<xla_hlo::AddOp>(sum_element_type, &exp_sum.body(), &rewriter);
+ Value *sum = exp_sum.getResult(0);
+
+ // Convert the summation result back to the original element type and
+ // divide the exponentials by the summation.
+ sum = rewriter.create<xla_hlo::ConvertOp>(loc, reduce_out_type, sum);
+ rewriter.replaceOpWithNewOp<xla_hlo::DivOp>(op.getOperation(), op.getType(),
+ exp, sum, batch_dims);
+ return matchSuccess();
+ }
+};
+
#include "tensorflow/compiler/mlir/xla/transforms/generated_legalize_tf.inc"
} // end anonymous namespace
} // end namespace xla
@@ -227,6 +353,7 @@
OwningRewritePatternList patterns;
xla::populateWithGenerated(op->getContext(), &patterns);
patterns.insert<mlir::xla::ConvertMaxPoolOp>(op->getContext());
+ patterns.insert<mlir::xla::ConvertSoftmaxOp>(op->getContext());
// Recursively applies rewrite patterns to nested operations.
applyPatternsGreedily(op, patterns);